mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-07 17:41:52 +08:00
[Backend] refactor paddle custom ops -> fastdeploy::paddle_custom_ops (#2101)
* [cmake] upgrade windows paddle inference -> 2.5.0 * [cmake] upgrade windows paddle inference -> 2.5.0 * fix paddle custom ops bug on windows * [Backend] refactor paddle custom ops
This commit is contained in:
@@ -64,6 +64,12 @@ DEFINE_string(optimized_model_dir, "",
|
||||
DEFINE_bool(collect_trt_shape_by_device, false,
|
||||
"Optional, whether collect trt shape by device. "
|
||||
"default false.");
|
||||
DEFINE_double(custom_tensor_value, 1.0,
|
||||
"Optional, set the value for fd tensor, "
|
||||
"default 1.0");
|
||||
DEFINE_bool(collect_trt_shape_by_custom_tensor_value, false,
|
||||
"Optional, whether collect trt shape by custom tensor value. "
|
||||
"default false.");
|
||||
|
||||
#if defined(ENABLE_BENCHMARK)
|
||||
static std::vector<int64_t> GetInt64Shape(const std::vector<int>& shape) {
|
||||
@@ -208,6 +214,23 @@ static void RuntimeProfiling(int argc, char* argv[]) {
|
||||
for (int i = 0; i < input_shapes.size(); ++i) {
|
||||
option.trt_option.SetShape(input_names[i], trt_shapes[i * 3],
|
||||
trt_shapes[i * 3 + 1], trt_shapes[i * 3 + 2]);
|
||||
// Set custom input data for collect trt shapes
|
||||
if (FLAGS_collect_trt_shape_by_custom_tensor_value) {
|
||||
int min_shape_num = std::accumulate(trt_shapes[i * 3].begin(),
|
||||
trt_shapes[i * 3].end(), 1,
|
||||
std::multiplies<int>());
|
||||
int opt_shape_num = std::accumulate(trt_shapes[i * 3 + 1].begin(),
|
||||
trt_shapes[i * 3 + 1].end(), 1,
|
||||
std::multiplies<int>());
|
||||
int max_shape_num = std::accumulate(trt_shapes[i * 3 + 2].begin(),
|
||||
trt_shapes[i * 3 + 2].end(), 1,
|
||||
std::multiplies<int>());
|
||||
std::vector<float> min_input_data(min_shape_num, FLAGS_custom_tensor_value);
|
||||
std::vector<float> opt_input_data(opt_shape_num, FLAGS_custom_tensor_value);
|
||||
std::vector<float> max_input_data(max_shape_num, FLAGS_custom_tensor_value);
|
||||
option.trt_option.SetInputData(input_names[i], min_input_data,
|
||||
opt_input_data, max_input_data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -232,8 +255,9 @@ static void RuntimeProfiling(int argc, char* argv[]) {
|
||||
// Feed inputs, all values set as FLAGS_custom_tensor_value (default 1.0).
|
||||
std::vector<fastdeploy::FDTensor> inputs(runtime.NumInputs());
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
fastdeploy::function::Full(1, GetInt64Shape(input_shapes[i]), &inputs[i],
|
||||
input_dtypes[i]);
|
||||
fastdeploy::function::Full(
|
||||
FLAGS_custom_tensor_value, GetInt64Shape(input_shapes[i]),
|
||||
&inputs[i], input_dtypes[i]);
|
||||
inputs[i].name = input_names[i];
|
||||
}
|
||||
|
||||
|
@@ -21,6 +21,7 @@ add_executable(benchmark_structurev2_table ${PROJECT_SOURCE_DIR}/benchmark_struc
|
||||
add_executable(benchmark_structurev2_layout ${PROJECT_SOURCE_DIR}/benchmark_structurev2_layout.cc)
|
||||
add_executable(benchmark_ppshituv2_rec ${PROJECT_SOURCE_DIR}/benchmark_ppshituv2_rec.cc)
|
||||
add_executable(benchmark_ppshituv2_det ${PROJECT_SOURCE_DIR}/benchmark_ppshituv2_det.cc)
|
||||
add_executable(benchmark_pp3d_centerpoint ${PROJECT_SOURCE_DIR}/benchmark_pp3d_centerpoint.cc)
|
||||
|
||||
if(UNIX AND (NOT APPLE) AND (NOT ANDROID))
|
||||
target_link_libraries(benchmark ${FASTDEPLOY_LIBS} gflags pthread)
|
||||
@@ -33,6 +34,7 @@ if(UNIX AND (NOT APPLE) AND (NOT ANDROID))
|
||||
target_link_libraries(benchmark_structurev2_layout ${FASTDEPLOY_LIBS} gflags pthread)
|
||||
target_link_libraries(benchmark_ppshituv2_rec ${FASTDEPLOY_LIBS} gflags pthread)
|
||||
target_link_libraries(benchmark_ppshituv2_det ${FASTDEPLOY_LIBS} gflags pthread)
|
||||
target_link_libraries(benchmark_pp3d_centerpoint ${FASTDEPLOY_LIBS} gflags pthread)
|
||||
else()
|
||||
target_link_libraries(benchmark ${FASTDEPLOY_LIBS} gflags)
|
||||
target_link_libraries(benchmark_ppcls ${FASTDEPLOY_LIBS} gflags)
|
||||
@@ -44,6 +46,7 @@ else()
|
||||
target_link_libraries(benchmark_structurev2_layout ${FASTDEPLOY_LIBS} gflags)
|
||||
target_link_libraries(benchmark_ppshituv2_rec ${FASTDEPLOY_LIBS} gflags)
|
||||
target_link_libraries(benchmark_ppshituv2_det ${FASTDEPLOY_LIBS} gflags)
|
||||
target_link_libraries(benchmark_pp3d_centerpoint ${FASTDEPLOY_LIBS} gflags)
|
||||
endif()
|
||||
# only for Android ADB test
|
||||
if(ANDROID)
|
||||
|
100
benchmark/paddlex/benchmark_pp3d_centerpoint.cc
Normal file
100
benchmark/paddlex/benchmark_pp3d_centerpoint.cc
Normal file
@@ -0,0 +1,100 @@
|
||||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "flags.h"
|
||||
#include "macros.h"
|
||||
#include "option.h"
|
||||
|
||||
namespace vision = fastdeploy::vision;
|
||||
namespace benchmark = fastdeploy::benchmark;
|
||||
|
||||
|
||||
// Loads a binary file of packed `float` values into `data`.
//
// The content is treated as consecutive points of `num_point_dim` (5) floats
// each. Since `with_timelag` is disabled, the fifth component (index 4) of
// every point is reset to 0 after loading.
//
// @param file_path  Path of the binary file to read.
// @param data       Receives the loaded floats on success. It is only
//                   modified when the function returns true.
// @return true on success; false if the file cannot be opened, its size
//         cannot be determined, the number of floats is not a multiple of
//         `num_point_dim`, or fewer bytes than expected could be read.
static bool ReadTestPoint(const std::string &file_path,
                          std::vector<float> &data) {
  const bool with_timelag = false;
  const size_t num_point_dim = 5;

  std::ifstream file_in(file_path, std::ios::in | std::ios::binary);
  if (!file_in) {
    std::cout << "Failed to read file: " << file_path << std::endl;
    return false;
  }

  // Determine the file size; tellg() yields -1 on failure, which must not be
  // turned into a (huge) element count.
  file_in.seekg(0, std::ios::end);
  const std::streamoff file_size = file_in.tellg();
  file_in.seekg(0, std::ios::beg);
  if (file_size < 0 || !file_in) {
    std::cout << "Failed to read file: " << file_path << std::endl;
    return false;
  }

  // Validate the layout before reading anything.
  const size_t num_floats = static_cast<size_t>(file_size) / sizeof(float);
  if (num_floats % num_point_dim != 0) {
    std::cout << "Loaded file size (" << file_size
              << ") is not evenly divisible by num_point_dim (" << num_point_dim
              << ")\n";
    return false;
  }

  // Read whole floats only. Requesting `file_size` bytes would overrun the
  // buffer whenever the file size is not a multiple of sizeof(float); any
  // trailing partial float is ignored instead.
  std::vector<float> points(num_floats);
  const std::streamsize num_bytes =
      static_cast<std::streamsize>(num_floats * sizeof(float));
  file_in.read(reinterpret_cast<char *>(points.data()), num_bytes);
  if (file_in.gcount() != num_bytes) {
    std::cout << "Failed to read file: " << file_path << std::endl;
    return false;
  }

  // Zero the fifth component of every point when it is not used as time lag.
  if ((!with_timelag && num_point_dim == 5) || num_point_dim > 5) {
    for (size_t i = 4; i < points.size(); i += num_point_dim) {
      points[i] = 0.0f;
    }
  }

  // Publish the result only after everything succeeded.
  data.swap(points);
  return true;
}
|
||||
|
||||
// Benchmark entry point for the CenterPoint perception model.
//
// When both ENABLE_BENCHMARK and ENABLE_VISION are defined it:
//   1. builds a RuntimeOption from the command line,
//   2. loads the benchmark config and resolves the model/params file names,
//   3. for the "paddle_trt" / "trt" backends, fixes the shape of the "data"
//      input and feeds the test point file as custom input for shape
//      collection,
//   4. constructs the model and profiles Predict() via BENCHMARK_MODEL.
// Otherwise it does nothing.
//
// @return -1 if option creation, resource-name resolution or loading the
//         test point file fails; 0 otherwise.
int main(int argc, char* argv[]) {
#if defined(ENABLE_BENCHMARK) && defined(ENABLE_VISION)
  // Initialization
  auto option = fastdeploy::RuntimeOption();
  if (!CreateRuntimeOption(&option, argc, argv, true)) {
    return -1;
  }
  // FLAGS_image carries the path of the point file for this benchmark.
  std::string point_dir = FLAGS_image;
  std::unordered_map<std::string, std::string> config_info;
  benchmark::ResultManager::LoadBenchmarkConfig(FLAGS_config_path,
                                                &config_info);
  std::string model_name, params_name, config_name;
  auto model_format = fastdeploy::ModelFormat::PADDLE;
  if (!UpdateModelResourceName(&model_name, &params_name, &config_name,
                               &model_format, config_info, false)) {
    return -1;
  }
  auto model_file = FLAGS_model + sep + model_name;
  auto params_file = FLAGS_model + sep + params_name;
  if (config_info["backend"] == "paddle_trt") {
    option.paddle_infer_option.collect_trt_shape = true;
    option.paddle_infer_option.collect_trt_shape_by_device = true;
  }
  if (config_info["backend"] == "paddle_trt" ||
      config_info["backend"] == "trt") {
    option.trt_option.SetShape("data", {34752, 5}, {34752, 5},
                               {34752, 5});
    std::vector<float> min_input_data;
    // Abort instead of handing an empty buffer to shape collection when the
    // point file cannot be loaded (ReadTestPoint reports the reason itself).
    if (!ReadTestPoint(point_dir, min_input_data)) {
      return -1;
    }
    // Use the loaded points as custom input data when collecting shapes.
    option.trt_option.SetInputData("data", min_input_data);
  }
  auto model_centerpoint = vision::perception::Centerpoint(
      model_file, params_file, "", option, model_format);
  vision::PerceptionResult res;
  // Run profiling
  BENCHMARK_MODEL(model_centerpoint, model_centerpoint.Predict(point_dir, &res))
  // std::cout << res.Str() << std::endl;
#endif

  return 0;
}
|
@@ -60,17 +60,24 @@ download PP-OCRv4-server-det.tgz
|
||||
|
||||
# PP-ShiTuV2
|
||||
download PP-ShiTuv2-rec.tgz
|
||||
download PP-ShiTuv2-det.tgz
|
||||
|
||||
# PP-StructureV2
|
||||
download PP-Structurev2-layout.tgz
|
||||
download PP-Structurev2-SLANet.tgz
|
||||
download PP-Structurev2-vi-layoutxlm.tgz
|
||||
|
||||
# Paddle3D
|
||||
download CADNN_OCRNet-HRNetW18.tgz
|
||||
download CenterPoint-Pillars-02Voxel.tgz
|
||||
download PETRv1_v99.tgz
|
||||
download PETRv2_v99.tgz
|
||||
|
||||
# Test resources
|
||||
# PaddleClas
|
||||
wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/ppcls_cls_demo.JPEG
|
||||
|
||||
# PaddleDetection
|
||||
# PaddleDetection & ppshitu-det
|
||||
wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/ppdet_det_img.jpg
|
||||
wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/ppdet_det_img_800x800.jpg
|
||||
|
||||
@@ -93,3 +100,9 @@ wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/structurev2_layout_v
|
||||
wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/structurev2_vi_layoutxml_zh_val_0.jpg
|
||||
wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/table_structure_dict_ch.txt
|
||||
wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/layout_cdla_dict.txt
|
||||
|
||||
# Paddle3D
|
||||
wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/paddle3d_cadnn_kitti_000780.png
|
||||
wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/paddle3d_centerpoint_n008_LIDAR_TOP__1533151603547590.pcd.bin
|
||||
wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/paddle3d_petrv1_v99_nuscenes_sample_6.tgz && tar -zxvf paddle3d_petrv1_v99_nuscenes_sample_6.tgz
|
||||
wget https://bj.bcebos.com/paddlehub/fastdeploy_paddlex_2_0/paddle3d_petrv2_v99_nuscenes_sample_12.tgz && tar -zxvf paddle3d_petrv2_v99_nuscenes_sample_12.tgz
|
||||
|
@@ -25,6 +25,9 @@
|
||||
#include "paddle/extension.h"
|
||||
#endif
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace paddle_custom_ops {
|
||||
|
||||
std::vector<paddle::Tensor> postprocess_gpu(
|
||||
const std::vector<paddle::Tensor> &hm,
|
||||
const std::vector<paddle::Tensor> ®,
|
||||
@@ -97,11 +100,14 @@ std::vector<paddle::DataType> PostProcessInferDtype(
|
||||
return {reg_dtype[0], hm_dtype[0], paddle::DataType::INT64};
|
||||
}
|
||||
|
||||
} // namespace paddle_custom_ops
|
||||
} // namespace fastdeploy
|
||||
|
||||
PD_BUILD_OP(centerpoint_postprocess)
|
||||
.Inputs({paddle::Vec("HM"), paddle::Vec("REG"), paddle::Vec("HEIGHT"),
|
||||
paddle::Vec("DIM"), paddle::Vec("VEL"), paddle::Vec("ROT")})
|
||||
.Outputs({"BBOXES", "SCORES", "LABELS"})
|
||||
.SetKernelFn(PD_KERNEL(centerpoint_postprocess))
|
||||
.SetKernelFn(PD_KERNEL(fastdeploy::paddle_custom_ops::centerpoint_postprocess))
|
||||
.Attrs({"voxel_size: std::vector<float>",
|
||||
"point_cloud_range: std::vector<float>",
|
||||
"post_center_range: std::vector<float>",
|
||||
@@ -109,7 +115,7 @@ PD_BUILD_OP(centerpoint_postprocess)
|
||||
"score_threshold: float", "nms_iou_threshold: float",
|
||||
"nms_pre_max_size: int", "nms_post_max_size: int",
|
||||
"with_velocity: bool"})
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(PostProcessInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(PostProcessInferDtype));
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(fastdeploy::paddle_custom_ops::PostProcessInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(fastdeploy::paddle_custom_ops::PostProcessInferDtype));
|
||||
|
||||
#endif // WITH_GPU
|
@@ -20,6 +20,9 @@
|
||||
#include "paddle/extension.h"
|
||||
#endif
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace paddle_custom_ops {
|
||||
|
||||
#define CHECK_INPUT_CUDA(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")
|
||||
|
||||
#define CHECK_INPUT_BATCHSIZE(x) \
|
||||
@@ -284,3 +287,6 @@ std::vector<paddle::Tensor> postprocess_gpu(
|
||||
auto out_bboxes = paddle::experimental::concat(bboxes, 0);
|
||||
return {out_bboxes, out_scores, out_labels};
|
||||
}
|
||||
|
||||
} // namespace paddle_custom_ops
|
||||
} // namespace fastdeploy
|
@@ -19,6 +19,10 @@ All Rights Reserved 2019-2020.
|
||||
*/
|
||||
|
||||
#include <stdio.h>
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace paddle_custom_ops {
|
||||
|
||||
#define THREADS_PER_BLOCK 16
|
||||
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
|
||||
|
||||
@@ -315,3 +319,6 @@ void NmsLauncher(const cudaStream_t &stream, const float *bboxes,
|
||||
num_bboxes, num_bboxes_for_nms, nms_overlap_thresh, decode_bboxes_dims,
|
||||
bboxes, index, sorted_index, mask);
|
||||
}
|
||||
|
||||
} // namespace paddle_custom_ops
|
||||
} // namespace fastdeploy
|
@@ -14,8 +14,6 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
#if defined(WITH_GPU)
|
||||
|
||||
#if defined(PADDLEINFERENCE_API_COMPAT_2_4_x)
|
||||
#include "paddle/include/experimental/ext_all.h"
|
||||
#elif defined(PADDLEINFERENCE_API_COMPAT_2_5_x)
|
||||
@@ -24,6 +22,9 @@
|
||||
#include "paddle/extension.h"
|
||||
#endif
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace paddle_custom_ops {
|
||||
|
||||
template <typename T, typename T_int>
|
||||
bool hard_voxelize_cpu_kernel(
|
||||
const T *points, const float point_cloud_range_x_min,
|
||||
@@ -147,7 +148,8 @@ std::vector<paddle::Tensor> hard_voxelize_cpu(
|
||||
return {voxels, coords, num_points_per_voxel, num_voxels};
|
||||
}
|
||||
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
|
||||
#if defined(PADDLE_WITH_CUDA) && defined(WITH_GPU)
|
||||
std::vector<paddle::Tensor> hard_voxelize_cuda(
|
||||
const paddle::Tensor &points, const std::vector<float> &voxel_size,
|
||||
const std::vector<float> &point_cloud_range, int max_num_points_in_voxel,
|
||||
@@ -161,7 +163,7 @@ std::vector<paddle::Tensor> hard_voxelize(
|
||||
if (points.is_cpu()) {
|
||||
return hard_voxelize_cpu(points, voxel_size, point_cloud_range,
|
||||
max_num_points_in_voxel, max_voxels);
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
#if defined(PADDLE_WITH_CUDA) && defined(WITH_GPU)
|
||||
} else if (points.is_gpu() || points.is_gpu_pinned()) {
|
||||
return hard_voxelize_cuda(points, voxel_size, point_cloud_range,
|
||||
max_num_points_in_voxel, max_voxels);
|
||||
@@ -188,14 +190,15 @@ std::vector<paddle::DataType> HardInferDtype(paddle::DataType points_dtype) {
|
||||
paddle::DataType::INT32};
|
||||
}
|
||||
|
||||
} // namespace paddle_custom_ops
|
||||
} // namespace fastdeploy
|
||||
|
||||
PD_BUILD_OP(hard_voxelize)
|
||||
.Inputs({"POINTS"})
|
||||
.Outputs({"VOXELS", "COORS", "NUM_POINTS_PER_VOXEL", "num_voxels"})
|
||||
.SetKernelFn(PD_KERNEL(hard_voxelize))
|
||||
.SetKernelFn(PD_KERNEL(fastdeploy::paddle_custom_ops::hard_voxelize))
|
||||
.Attrs({"voxel_size: std::vector<float>",
|
||||
"point_cloud_range: std::vector<float>",
|
||||
"max_num_points_in_voxel: int", "max_voxels: int"})
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(HardInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(HardInferDtype));
|
||||
|
||||
#endif // WITH_GPU
|
||||
.SetInferShapeFn(PD_INFER_SHAPE(fastdeploy::paddle_custom_ops::HardInferShape))
|
||||
.SetInferDtypeFn(PD_INFER_DTYPE(fastdeploy::paddle_custom_ops::HardInferDtype));
|
@@ -20,6 +20,9 @@
|
||||
#include "paddle/extension.h"
|
||||
#endif
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace paddle_custom_ops {
|
||||
|
||||
#define CHECK_INPUT_CUDA(x) \
|
||||
PD_CHECK(x.is_gpu() || x.is_gpu_pinned(), #x " must be a GPU Tensor.")
|
||||
|
||||
@@ -349,3 +352,6 @@ std::vector<paddle::Tensor> hard_voxelize_cuda(
|
||||
|
||||
return {voxels, coords, num_points_per_voxel, num_voxels};
|
||||
}
|
||||
|
||||
} // namespace paddle_custom_ops
|
||||
} // namespace fastdeploy
|
@@ -111,6 +111,8 @@ struct PaddleBackendOption {
|
||||
int gpu_mem_init_size = 100;
|
||||
/// The option to enable fixed size optimization for transformer model
|
||||
bool enable_fixed_size_opt = false;
|
||||
/// min_subgraph_size for paddle-trt
|
||||
int trt_min_subgraph_size = 3;
|
||||
|
||||
/// Disable type of operators run on TensorRT
|
||||
void DisableTrtOps(const std::vector<std::string>& ops) {
|
||||
|
@@ -59,6 +59,7 @@ void BindPaddleOption(pybind11::module& m) {
|
||||
&PaddleBackendOption::is_quantize_model)
|
||||
.def_readwrite("inference_precision", &PaddleBackendOption::inference_precision)
|
||||
.def_readwrite("enable_inference_cutlass",&PaddleBackendOption::enable_inference_cutlass)
|
||||
.def_readwrite("trt_min_subgraph_size",&PaddleBackendOption::trt_min_subgraph_size)
|
||||
.def("disable_trt_ops", &PaddleBackendOption::DisableTrtOps)
|
||||
.def("delete_pass", &PaddleBackendOption::DeletePass)
|
||||
.def("set_ipu_config", &PaddleBackendOption::SetIpuConfig);
|
||||
|
@@ -88,7 +88,8 @@ void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
|
||||
config_.SetOptimCacheDir(opt_cache_dir);
|
||||
}
|
||||
config_.EnableTensorRtEngine(option.trt_option.max_workspace_size,
|
||||
option.trt_option.max_batch_size, 3,
|
||||
option.trt_option.max_batch_size,
|
||||
option.trt_min_subgraph_size,
|
||||
precision, use_static);
|
||||
SetTRTDynamicShapeToConfig(option);
|
||||
if (option_.enable_fixed_size_opt) {
|
||||
@@ -225,7 +226,8 @@ bool PaddleBackend::InitFromPaddle(const std::string& model,
|
||||
use_static = true;
|
||||
}
|
||||
config_.EnableTensorRtEngine(option.trt_option.max_workspace_size,
|
||||
option.trt_option.max_batch_size, 3,
|
||||
option.trt_option.max_batch_size,
|
||||
option.trt_min_subgraph_size,
|
||||
paddle_infer::PrecisionType::kInt8,
|
||||
use_static, false);
|
||||
SetTRTDynamicShapeToConfig(option);
|
||||
|
Reference in New Issue
Block a user