[Backend] add paddle custom ops compatible policy (#2070)

* Add centerpoint

* fix postprocess op file name

* [Backend] add paddle custom ops compatible policy

* [Backend] add paddle custom ops compatible policy

* [Backend] add paddle custom ops compatible policy

* upgrade linx paddle gpu -> 2.5

* add custom op compatible policy for paddle 2.5

* add custom op compatible policy for paddle 2.5

* add custom op compatible policy for paddle 2.5

* add collect_trt_shape_by_device option for paddle backend

* add collect_trt_shape_by_device option for paddle backend

* add custom op option for python build

* fix python build bugs

* update paddle linux x86 cpu only lib

* update paddle linux gpu lib

* update patchelf cmake

* fix paddle backend option pybind

* update paddle_inference.cmake

* add cuda sm_80 support (A100)

---------

Co-authored-by: zengshao0622 <peter_z96@163.com>
Co-authored-by: qiuyanjun <qiuyanjun@baidu.com>
This commit is contained in:
DefTruth
2023-06-29 22:32:14 +08:00
committed by GitHub
parent 4c3e7030e1
commit b2426aefa9
16 changed files with 1423 additions and 42 deletions

View File

@@ -36,7 +36,9 @@ include(${PROJECT_SOURCE_DIR}/cmake/utils.cmake)
# Set C++11 as standard for the whole project
if(NOT MSVC)
set(CMAKE_CXX_STANDARD 11)
if(NOT DEFINED CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 11)
endif()
set(CMAKE_CXX_FLAGS "-Wno-format -g0 -O3")
if(NEED_ABI0)
add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
@@ -177,21 +179,26 @@ file(GLOB_RECURSE DEPLOY_PIPELINE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/f
file(GLOB_RECURSE DEPLOY_VISION_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/vision/*.cc)
file(GLOB_RECURSE DEPLOY_TEXT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/text/*.cc)
file(GLOB_RECURSE DEPLOY_PYBIND_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/*.cc ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*_pybind.cc)
file(GLOB_RECURSE DEPLOY_PADDLE_CUSTOM_OP_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/runtime/backends/paddle/ops/*.cc)
if(WITH_GPU)
file(GLOB_RECURSE DEPLOY_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*.cu)
list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_CUDA_SRCS})
file(GLOB_RECURSE DEPLOY_PADDLE_CUSTOM_OP_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/runtime/backends/paddle/ops/*.cu)
list(REMOVE_ITEM ALL_DEPLOY_SRCS ${DEPLOY_PADDLE_CUSTOM_OP_CUDA_SRCS})
file(GLOB_RECURSE DEPLOY_VISION_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/vision/*.cu)
list(APPEND DEPLOY_VISION_SRCS ${DEPLOY_VISION_CUDA_SRCS})
file(GLOB_RECURSE DEPLOY_TEXT_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/text/*.cu)
list(APPEND DEPLOY_TEXT_SRCS ${DEPLOY_TEXT_CUDA_SRCS})
endif()
list(REMOVE_ITEM DEPLOY_PADDLE_SRCS ${DEPLOY_PADDLE_CUSTOM_OP_SRCS})
list(REMOVE_ITEM ALL_DEPLOY_SRCS ${DEPLOY_ORT_SRCS} ${DEPLOY_PADDLE_SRCS}
${DEPLOY_POROS_SRCS} ${DEPLOY_TRT_SRCS}
${DEPLOY_OPENVINO_SRCS} ${DEPLOY_LITE_SRCS}
${DEPLOY_VISION_SRCS} ${DEPLOY_TEXT_SRCS}
${DEPLOY_PIPELINE_SRCS} ${DEPLOY_RKNPU2_SRCS}
${DEPLOY_SOPHGO_SRCS} ${DEPLOY_ENCRYPTION_SRCS}
${DEPLOY_HORIZON_SRCS} ${DEPLOY_TVM_SRCS})
${DEPLOY_HORIZON_SRCS} ${DEPLOY_TVM_SRCS}
${DEPLOY_PADDLE_CUSTOM_OP_SRCS})
set(DEPEND_LIBS "")
@@ -243,6 +250,13 @@ if(ENABLE_PADDLE_BACKEND)
if(external_ort_FOUND)
list(APPEND DEPEND_LIBS external_p2o external_ort)
endif()
if(PADDLEINFERENCE_API_CUSTOM_OP)
set_paddle_custom_ops_compatible_policy()
list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_PADDLE_CUSTOM_OP_SRCS})
if(WITH_GPU)
list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_PADDLE_CUSTOM_OP_CUDA_SRCS})
endif()
endif()
endif()
if(ENABLE_OPENVINO_BACKEND)
@@ -521,7 +535,7 @@ endif()
target_link_libraries(${LIBRARY_NAME} ${DEPEND_LIBS})
if(ENABLE_PADDLE_BACKEND)
set_paddle_encrypt_auth_link_policy(${LIBRARY_NAME})
set_paddle_encrypt_auth_compatible_policy(${LIBRARY_NAME})
endif()
if(ANDROID)
@@ -751,3 +765,4 @@ if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
message(FATAL_ERROR "[ERROR] FastDeploy require g++ version >= 5.4.0, but now your g++ version is ${CMAKE_CXX_COMPILER_VERSION}, this may cause failure! Use -DCMAKE_CXX_COMPILER to define path of your compiler.")
endif()
endif()

View File

@@ -61,6 +61,9 @@ DEFINE_string(optimized_model_dir, "",
"Optional, set optimized model dir for lite."
"eg: model.opt.nb, "
"default ''");
DEFINE_bool(collect_trt_shape_by_device, false,
"Optional, whether collect trt shape by device. "
"default false.");
#if defined(ENABLE_BENCHMARK)
static std::vector<int64_t> GetInt64Shape(const std::vector<int>& shape) {
@@ -188,6 +191,8 @@ static void RuntimeProfiling(int argc, char* argv[]) {
// Set tensorrt shapes
if (config_info["backend"] == "paddle_trt") {
option.paddle_infer_option.collect_trt_shape = true;
option.paddle_infer_option.collect_trt_shape_by_device =
FLAGS_collect_trt_shape_by_device;
}
if (config_info["backend"] == "paddle_trt" ||
config_info["backend"] == "trt") {

View File

@@ -246,7 +246,11 @@ message(STATUS "NVCC_FLAGS_EXTRA: ${NVCC_FLAGS_EXTRA}")
set(CUDA_PROPAGATE_HOST_FLAGS OFF)
# Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc.
# So, don't set these flags here.
set(CMAKE_CUDA_STANDARD 11)
if(NOT DEFINED CMAKE_CUDA_STANDARD)
set(CMAKE_CUDA_STANDARD 11)
else()
message(WARNING "Detected custom CMAKE_CUDA_STANDARD is using: ${CMAKE_CUDA_STANDARD}")
endif()
# (Note) For windows, if delete /W[1-4], /W1 will be added defaultly and conflic with -w
# So replace /W[1-4] with /W0

View File

@@ -22,7 +22,11 @@ endif()
# Custom options for Paddle Inference backend
option(PADDLEINFERENCE_DIRECTORY "Directory of custom Paddle Inference library" OFF)
option(PADDLEINFERENCE_API_CUSTOM_OP "Whether building with custom paddle ops" OFF)
option(PADDLEINFERENCE_API_COMPAT_2_4_x "Whether using Paddle Inference 2.4.x" OFF)
option(PADDLEINFERENCE_API_COMPAT_2_5_x "Whether using Paddle Inference 2.5.x" OFF)
option(PADDLEINFERENCE_API_COMPAT_DEV "Whether using Paddle Inference latest dev" OFF)
option(PADDLEINFERENCE_API_COMPAT_CUDA_SM_80 "Whether using Paddle Inference with CUDA sm_80(A100)" OFF)
set(PADDLEINFERENCE_PROJECT "extern_paddle_inference")
set(PADDLEINFERENCE_PREFIX_DIR ${THIRD_PARTY_PATH}/paddle_inference)
@@ -93,11 +97,16 @@ else()
else()
# x86_64
if(WITH_GPU)
set(PADDLEINFERENCE_FILE "paddle_inference-linux-x64-gpu-trt8.5.2.2-mkl-avx-0.0.0.660f781b77.tgz")
set(PADDLEINFERENCE_VERSION "0.0.0.660f781b77")
if(PADDLEINFERENCE_API_COMPAT_CUDA_SM_80)
set(PADDLEINFERENCE_FILE "paddle_inference-linux-x64-gpu-trt8.5.2.2-mkl-sm70.sm75.sm80.sm86.nodist-2.5.0.558ae9cd11.tgz")
set(PADDLEINFERENCE_VERSION "2.5.0.558ae9cd11")
else()
set(PADDLEINFERENCE_FILE "paddle_inference-linux-x64-gpu-trt8.5.2.2-mkl-sm61.sm70.sm75.sm86.nodist-2.5.0.558ae9cd11.tgz")
set(PADDLEINFERENCE_VERSION "2.5.0.558ae9cd11")
endif()
else()
set(PADDLEINFERENCE_FILE "paddle_inference-linux-x64-mkl-avx-0.0.0.660f781b77.tgz")
set(PADDLEINFERENCE_VERSION "0.0.0.660f781b77")
set(PADDLEINFERENCE_FILE "paddle_inference-linux-x64-mkl-2.5.0.558ae9cd11.tgz")
set(PADDLEINFERENCE_VERSION "2.5.0.558ae9cd11")
endif()
if(WITH_IPU)
set(PADDLEINFERENCE_FILE "paddle_inference-linux-x64-ipu-2.4-dev1.tgz")
@@ -173,18 +182,26 @@ else()
set(FDMODEL_LEVELDB_LIB_LIB "${PADDLEINFERENCE_INSTALL_DIR}/third_party/install/leveldb/lib/libleveldb.a")
if((EXISTS ${FDMODEL_LIB}) AND (EXISTS ${FDMODEL_MODEL_LIB}))
set(PADDLEINFERENCE_WITH_ENCRYPT ON CACHE BOOL "" FORCE)
message(STATUS "Detected ${FDMODEL_LIB} and ${FDMODEL_MODEL_LIB} exists, fource PADDLEINFERENCE_WITH_ENCRYPT=${PADDLEINFERENCE_WITH_ENCRYPT}")
message(STATUS "Detected ${FDMODEL_LIB} and ${FDMODEL_MODEL_LIB} exists, force PADDLEINFERENCE_WITH_ENCRYPT=${PADDLEINFERENCE_WITH_ENCRYPT}")
endif()
if((EXISTS ${FDMODEL_LIB}) AND (EXISTS ${FDMODEL_AUTH_LIB}))
set(PADDLEINFERENCE_WITH_AUTH ON CACHE BOOL "" FORCE)
message(STATUS "Detected ${FDMODEL_LIB} and ${FDMODEL_AUTH_LIB} exists, fource PADDLEINFERENCE_WITH_AUTH=${PADDLEINFERENCE_WITH_AUTH}")
message(STATUS "Detected ${FDMODEL_LIB} and ${FDMODEL_AUTH_LIB} exists, force PADDLEINFERENCE_WITH_AUTH=${PADDLEINFERENCE_WITH_AUTH}")
endif()
endif()
endif(WIN32)
# Path Paddle Inference ELF lib file
if(UNIX AND (NOT APPLE) AND (NOT ANDROID))
add_custom_target(patchelf_paddle_inference ALL COMMAND bash -c "PATCHELF_EXE=${PATCHELF_EXE} python ${PROJECT_SOURCE_DIR}/scripts/patch_paddle_inference.py ${PADDLEINFERENCE_INSTALL_DIR}/paddle/lib/libpaddle_inference.so" DEPENDS ${LIBRARY_NAME})
set(PATCHELF_SCRIPT ${PROJECT_SOURCE_DIR}/scripts/patch_paddle_inference.py)
set(PATCHELF_TARGET ${PADDLEINFERENCE_INSTALL_DIR}/paddle/lib/libpaddle_inference.so)
add_custom_target(
patchelf_paddle_inference ALL COMMAND bash -c
"PATCHELF_EXE=${PATCHELF_EXE} python ${PATCHELF_SCRIPT} ${PATCHELF_TARGET}"
DEPENDS ${LIBRARY_NAME}
)
unset(PATCHELF_SCRIPT)
unset(PATCHELF_TARGET)
endif()
add_library(external_paddle_inference STATIC IMPORTED GLOBAL)
@@ -232,7 +249,8 @@ if(PADDLEINFERENCE_WITH_AUTH)
list(APPEND ENCRYPT_AUTH_LIBS external_fdmodel_auth)
endif()
function(set_paddle_encrypt_auth_link_policy LIBRARY_NAME)
# Compatible policy for paddle with encrypt and auth
function(set_paddle_encrypt_auth_compatible_policy LIBRARY_NAME)
if(ENABLE_PADDLE_BACKEND AND (PADDLEINFERENCE_WITH_ENCRYPT OR PADDLEINFERENCE_WITH_AUTH))
target_link_libraries(${LIBRARY_NAME} ${ENCRYPT_AUTH_LIBS})
# Note(qiuyanjun): Currently, for XPU, we need to manually link the whole
@@ -249,13 +267,20 @@ function(set_paddle_encrypt_auth_link_policy LIBRARY_NAME)
endif()
endfunction()
# Backward compatible for 2.4.x
string(FIND ${PADDLEINFERENCE_VERSION} "2.4" PADDLEINFERENCE_USE_2_4_x)
string(FIND ${PADDLEINFERENCE_VERSION} "2.5" PADDLEINFERENCE_USE_2_5_x)
string(FIND ${PADDLEINFERENCE_VERSION} "0.0.0" PADDLEINFERENCE_USE_DEV)
# Compatible policy for 2.4.x/2.5.x and latest dev.
string(REGEX MATCH "0.0.0" PADDLEINFERENCE_USE_DEV ${PADDLEINFERENCE_VERSION})
string(REGEX MATCH "2.4|post24|post2.4" PADDLEINFERENCE_USE_2_4_x ${PADDLEINFERENCE_VERSION})
string(REGEX MATCH "2.5|post25|post2.5" PADDLEINFERENCE_USE_2_5_x ${PADDLEINFERENCE_VERSION})
if((NOT (PADDLEINFERENCE_USE_2_4_x EQUAL -1))
AND (PADDLEINFERENCE_USE_2_5_x EQUAL -1) AND (PADDLEINFERENCE_USE_DEV EQUAL -1))
if(PADDLEINFERENCE_USE_DEV)
set(PADDLEINFERENCE_API_COMPAT_DEV ON CACHE BOOL "" FORCE)
endif()
if(PADDLEINFERENCE_USE_2_5_x)
set(PADDLEINFERENCE_API_COMPAT_2_5_x ON CACHE BOOL "" FORCE)
endif()
if(PADDLEINFERENCE_USE_2_4_x AND (NOT PADDLEINFERENCE_API_COMPAT_2_5_x) AND (NOT PADDLEINFERENCE_API_COMPAT_DEV))
set(PADDLEINFERENCE_API_COMPAT_2_4_x ON CACHE BOOL "" FORCE)
message(WARNING "You are using PADDLEINFERENCE_USE_2_4_x:${PADDLEINFERENCE_VERSION}, force PADDLEINFERENCE_API_COMPAT_2_4_x=ON")
endif()
@@ -263,3 +288,55 @@ endif()
if(PADDLEINFERENCE_API_COMPAT_2_4_x)
add_definitions(-DPADDLEINFERENCE_API_COMPAT_2_4_x)
endif()
if(PADDLEINFERENCE_API_COMPAT_2_5_x)
add_definitions(-DPADDLEINFERENCE_API_COMPAT_2_5_x)
endif()
if(PADDLEINFERENCE_API_COMPAT_DEV)
add_definitions(-DPADDLEINFERENCE_API_COMPAT_DEV)
endif()
# Compatible policy for custom paddle ops
if(PADDLEINFERENCE_API_COMPAT_2_5_x)
# no c++ standard policy conflicts vs c++ 11
# TODO: support custom ops for latest dev
set(PADDLEINFERENCE_API_CUSTOM_OP ON CACHE BOOL "" FORCE)
# add paddle_inference/paddle/include path for custom ops
# the extension.h and it's deps headers are located in
# paddle/include/paddle directory.
include_directories(${PADDLEINFERENCE_INC_DIR}/paddle/include)
message(WARNING "You are using PADDLEINFERENCE_API_COMPAT_2_5_x:${PADDLEINFERENCE_VERSION}, force PADDLEINFERENCE_API_CUSTOM_OP=${PADDLEINFERENCE_API_CUSTOM_OP}")
endif()
function(set_paddle_custom_ops_compatible_policy)
if(PADDLEINFERENCE_API_CUSTOM_OP AND (NOT MSVC))
# TODO: add non c++ 14 policy for latest dev
if(NOT PADDLEINFERENCE_API_COMPAT_2_5_x)
# gcc c++ 14 policy for 2.4.x
if(NOT DEFINED CMAKE_CXX_STANDARD)
set(CMAKE_CXX_STANDARD 14 PARENT_SCOPE)
message(WARNING "Found PADDLEINFERENCE_API_CUSTOM_OP=ON, but CMAKE_CXX_STANDARD is not defined, use c++ 14 by default!")
elseif(NOT (CMAKE_CXX_STANDARD EQUAL 14))
set(CMAKE_CXX_STANDARD 14 PARENT_SCOPE)
message(WARNING "Found PADDLEINFERENCE_API_CUSTOM_OP=ON, force use c++ 14!")
endif()
endif()
if(WITH_GPU)
# cuda c++ 14 policy for 2.4.x
if(NOT PADDLEINFERENCE_API_COMPAT_2_5_x)
if(NOT DEFINED CMAKE_CUDA_STANDARD)
set(CMAKE_CUDA_STANDARD 14 PARENT_SCOPE)
message(WARNING "Found PADDLEINFERENCE_API_CUSTOM_OP=ON and WITH_GPU=ON, but CMAKE_CUDA_STANDARD is not defined, use c++ 14 by default!")
elseif(NOT (CMAKE_CUDA_STANDARD EQUAL 14))
set(CMAKE_CUDA_STANDARD 14 PARENT_SCOPE)
message(WARNING "Found PADDLEINFERENCE_API_CUSTOM_OP=ON and WITH_GPU=ON, force use c++ 14!")
endif()
endif()
# compile flags for paddle custom ops
add_definitions(-DPADDLE_WITH_CUDA)
add_definitions(-DPADDLE_ON_INFERENCE)
add_definitions(-DPADDLE_NO_PYTHON)
endif()
endif()
endfunction()

View File

@@ -19,6 +19,8 @@ function(fastdeploy_summary)
message(STATUS " CMake command : ${CMAKE_COMMAND}")
message(STATUS " System : ${CMAKE_SYSTEM_NAME}")
message(STATUS " C++ compiler : ${CMAKE_CXX_COMPILER}")
message(STATUS " C++ standard : ${CMAKE_CXX_STANDARD}")
message(STATUS " C++ cuda standard : ${CMAKE_CUDA_STANDARD}")
message(STATUS " C++ compiler version : ${CMAKE_CXX_COMPILER_VERSION}")
message(STATUS " CXX flags : ${CMAKE_CXX_FLAGS}")
message(STATUS " EXE linker flags : ${CMAKE_EXE_LINKER_FLAGS}")

View File

@@ -0,0 +1,115 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(WITH_GPU)
#include <cuda.h>
#include <cuda_runtime_api.h>
#if defined(PADDLEINFERENCE_API_COMPAT_2_4_x)
#include "paddle/include/experimental/ext_all.h"
#elif defined(PADDLEINFERENCE_API_COMPAT_2_5_x)
#include "paddle/include/paddle/extension.h"
#else
#include "paddle/extension.h"
#endif
std::vector<paddle::Tensor> postprocess_gpu(
const std::vector<paddle::Tensor> &hm,
const std::vector<paddle::Tensor> &reg,
const std::vector<paddle::Tensor> &height,
const std::vector<paddle::Tensor> &dim,
const std::vector<paddle::Tensor> &vel,
const std::vector<paddle::Tensor> &rot,
const std::vector<float> &voxel_size,
const std::vector<float> &point_cloud_range,
const std::vector<float> &post_center_range,
const std::vector<int> &num_classes, const int down_ratio,
const float score_threshold, const float nms_iou_threshold,
const int nms_pre_max_size, const int nms_post_max_size,
const bool with_velocity);
std::vector<paddle::Tensor> centerpoint_postprocess(
const std::vector<paddle::Tensor> &hm,
const std::vector<paddle::Tensor> &reg,
const std::vector<paddle::Tensor> &height,
const std::vector<paddle::Tensor> &dim,
const std::vector<paddle::Tensor> &vel,
const std::vector<paddle::Tensor> &rot,
const std::vector<float> &voxel_size,
const std::vector<float> &point_cloud_range,
const std::vector<float> &post_center_range,
const std::vector<int> &num_classes, const int down_ratio,
const float score_threshold, const float nms_iou_threshold,
const int nms_pre_max_size, const int nms_post_max_size,
const bool with_velocity) {
if (hm[0].is_gpu()) {
return postprocess_gpu(hm, reg, height, dim, vel, rot, voxel_size,
point_cloud_range, post_center_range, num_classes,
down_ratio, score_threshold, nms_iou_threshold,
nms_pre_max_size, nms_post_max_size, with_velocity);
} else {
PD_THROW(
"Unsupported device type for centerpoint postprocess "
"operator.");
}
}
std::vector<std::vector<int64_t>> PostProcessInferShape(
const std::vector<std::vector<int64_t>> &hm_shape,
const std::vector<std::vector<int64_t>> &reg_shape,
const std::vector<std::vector<int64_t>> &height_shape,
const std::vector<std::vector<int64_t>> &dim_shape,
const std::vector<std::vector<int64_t>> &vel_shape,
const std::vector<std::vector<int64_t>> &rot_shape,
const std::vector<float> &voxel_size,
const std::vector<float> &point_cloud_range,
const std::vector<float> &post_center_range,
const std::vector<int> &num_classes, const int down_ratio,
const float score_threshold, const float nms_iou_threshold,
const int nms_pre_max_size, const int nms_post_max_size,
const bool with_velocity) {
if (with_velocity) {
return {{-1, 9}, {-1}, {-1}};
} else {
return {{-1, 7}, {-1}, {-1}};
}
}
std::vector<paddle::DataType> PostProcessInferDtype(
const std::vector<paddle::DataType> &hm_dtype,
const std::vector<paddle::DataType> &reg_dtype,
const std::vector<paddle::DataType> &height_dtype,
const std::vector<paddle::DataType> &dim_dtype,
const std::vector<paddle::DataType> &vel_dtype,
const std::vector<paddle::DataType> &rot_dtype) {
return {reg_dtype[0], hm_dtype[0], paddle::DataType::INT64};
}
PD_BUILD_OP(centerpoint_postprocess)
.Inputs({paddle::Vec("HM"), paddle::Vec("REG"), paddle::Vec("HEIGHT"),
paddle::Vec("DIM"), paddle::Vec("VEL"), paddle::Vec("ROT")})
.Outputs({"BBOXES", "SCORES", "LABELS"})
.SetKernelFn(PD_KERNEL(centerpoint_postprocess))
.Attrs({"voxel_size: std::vector<float>",
"point_cloud_range: std::vector<float>",
"post_center_range: std::vector<float>",
"num_classes: std::vector<int>", "down_ratio: int",
"score_threshold: float", "nms_iou_threshold: float",
"nms_pre_max_size: int", "nms_post_max_size: int",
"with_velocity: bool"})
.SetInferShapeFn(PD_INFER_SHAPE(PostProcessInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(PostProcessInferDtype));
#endif // WITH_GPU

View File

@@ -0,0 +1,286 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLEINFERENCE_API_COMPAT_2_4_x)
#include "paddle/include/experimental/ext_all.h"
#elif defined(PADDLEINFERENCE_API_COMPAT_2_5_x)
#include "paddle/include/paddle/extension.h"
#else
#include "paddle/extension.h"
#endif
#define CHECK_INPUT_CUDA(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")
#define CHECK_INPUT_BATCHSIZE(x) \
PD_CHECK(x.shape()[0] == 1, #x " batch size must be 1.")
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
const int THREADS_PER_BLOCK_NMS = sizeof(int64_t) * 8;
void NmsLauncher(const cudaStream_t &stream, const float *bboxes,
const int *index, const int64_t *sorted_index,
const int num_bboxes, const int num_bboxes_for_nms,
const float nms_overlap_thresh, const int decode_bboxes_dims,
int64_t *mask);
__global__ void decode_kernel(
const float *score, const float *reg, const float *height, const float *dim,
const float *vel, const float *rot, const float score_threshold,
const int feat_w, const float down_ratio, const float voxel_size_x,
const float voxel_size_y, const float point_cloud_range_x_min,
const float point_cloud_range_y_min, const float post_center_range_x_min,
const float post_center_range_y_min, const float post_center_range_z_min,
const float post_center_range_x_max, const float post_center_range_y_max,
const float post_center_range_z_max, const int num_bboxes,
const bool with_velocity, const int decode_bboxes_dims, float *bboxes,
bool *mask, int *score_idx) {
int box_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (box_idx == num_bboxes || box_idx > num_bboxes) {
return;
}
const int xs = box_idx % feat_w;
const int ys = box_idx / feat_w;
float x = reg[box_idx];
float y = reg[box_idx + num_bboxes];
float z = height[box_idx];
bboxes[box_idx * decode_bboxes_dims] =
(x + xs) * down_ratio * voxel_size_x + point_cloud_range_x_min;
bboxes[box_idx * decode_bboxes_dims + 1] =
(y + ys) * down_ratio * voxel_size_y + point_cloud_range_y_min;
bboxes[box_idx * decode_bboxes_dims + 2] = z;
bboxes[box_idx * decode_bboxes_dims + 3] = dim[box_idx];
bboxes[box_idx * decode_bboxes_dims + 4] = dim[box_idx + num_bboxes];
bboxes[box_idx * decode_bboxes_dims + 5] = dim[box_idx + 2 * num_bboxes];
if (with_velocity) {
bboxes[box_idx * decode_bboxes_dims + 6] = vel[box_idx];
bboxes[box_idx * decode_bboxes_dims + 7] = vel[box_idx + num_bboxes];
bboxes[box_idx * decode_bboxes_dims + 8] =
atan2f(rot[box_idx], rot[box_idx + num_bboxes]);
} else {
bboxes[box_idx * decode_bboxes_dims + 6] =
atan2f(rot[box_idx], rot[box_idx + num_bboxes]);
}
if (score[box_idx] > score_threshold && x <= post_center_range_x_max &&
y <= post_center_range_y_max && z <= post_center_range_z_max &&
x >= post_center_range_x_min && y >= post_center_range_y_min &&
z >= post_center_range_z_min) {
mask[box_idx] = true;
}
score_idx[box_idx] = box_idx;
}
void DecodeLauncher(
const cudaStream_t &stream, const float *score, const float *reg,
const float *height, const float *dim, const float *vel, const float *rot,
const float score_threshold, const int feat_w, const float down_ratio,
const float voxel_size_x, const float voxel_size_y,
const float point_cloud_range_x_min, const float point_cloud_range_y_min,
const float post_center_range_x_min, const float post_center_range_y_min,
const float post_center_range_z_min, const float post_center_range_x_max,
const float post_center_range_y_max, const float post_center_range_z_max,
const int num_bboxes, const bool with_velocity,
const int decode_bboxes_dims, float *bboxes, bool *mask, int *score_idx) {
dim3 blocks(DIVUP(num_bboxes, THREADS_PER_BLOCK_NMS));
dim3 threads(THREADS_PER_BLOCK_NMS);
decode_kernel<<<blocks, threads, 0, stream>>>(
score, reg, height, dim, vel, rot, score_threshold, feat_w, down_ratio,
voxel_size_x, voxel_size_y, point_cloud_range_x_min,
point_cloud_range_y_min, post_center_range_x_min, post_center_range_y_min,
post_center_range_z_min, post_center_range_x_max, post_center_range_y_max,
post_center_range_z_max, num_bboxes, with_velocity, decode_bboxes_dims,
bboxes, mask, score_idx);
}
std::vector<paddle::Tensor> postprocess_gpu(
const std::vector<paddle::Tensor> &hm,
const std::vector<paddle::Tensor> &reg,
const std::vector<paddle::Tensor> &height,
const std::vector<paddle::Tensor> &dim,
const std::vector<paddle::Tensor> &vel,
const std::vector<paddle::Tensor> &rot,
const std::vector<float> &voxel_size,
const std::vector<float> &point_cloud_range,
const std::vector<float> &post_center_range,
const std::vector<int> &num_classes, const int down_ratio,
const float score_threshold, const float nms_iou_threshold,
const int nms_pre_max_size, const int nms_post_max_size,
const bool with_velocity) {
int num_tasks = hm.size();
int decode_bboxes_dims = 9;
if (!with_velocity) {
decode_bboxes_dims = 7;
}
float voxel_size_x = voxel_size[0];
float voxel_size_y = voxel_size[1];
float point_cloud_range_x_min = point_cloud_range[0];
float point_cloud_range_y_min = point_cloud_range[1];
float post_center_range_x_min = post_center_range[0];
float post_center_range_y_min = post_center_range[1];
float post_center_range_z_min = post_center_range[2];
float post_center_range_x_max = post_center_range[3];
float post_center_range_y_max = post_center_range[4];
float post_center_range_z_max = post_center_range[5];
std::vector<paddle::Tensor> scores;
std::vector<paddle::Tensor> labels;
std::vector<paddle::Tensor> bboxes;
for (int task_id = 0; task_id < num_tasks; ++task_id) {
CHECK_INPUT_BATCHSIZE(hm[0]);
int feat_h = hm[0].shape()[2];
int feat_w = hm[0].shape()[3];
int num_bboxes = feat_h * feat_w;
// score and label
auto sigmoid_hm_per_task = paddle::experimental::sigmoid(hm[task_id]);
auto label_per_task =
paddle::experimental::argmax(sigmoid_hm_per_task, 1, true, false, 3);
auto score_per_task =
paddle::experimental::max(sigmoid_hm_per_task, {1}, true);
// dim
auto exp_dim_per_task = paddle::experimental::exp(dim[task_id]);
// decode bboxed and get mask of bboxes for nms
const float *score_ptr = score_per_task.data<float>();
const float *reg_ptr = reg[task_id].data<float>();
const float *height_ptr = height[task_id].data<float>();
// const float* dim_ptr = dim[task_id].data<float>();
const float *exp_dim_per_task_ptr = exp_dim_per_task.data<float>();
const float *vel_ptr = vel[task_id].data<float>();
const float *rot_ptr = rot[task_id].data<float>();
auto decode_bboxes =
paddle::empty({num_bboxes, decode_bboxes_dims},
paddle::DataType::FLOAT32, paddle::GPUPlace());
float *decode_bboxes_ptr = decode_bboxes.data<float>();
auto thresh_mask = paddle::full({num_bboxes}, 0, paddle::DataType::BOOL,
paddle::GPUPlace());
bool *thresh_mask_ptr = thresh_mask.data<bool>();
auto score_idx = paddle::empty({num_bboxes}, paddle::DataType::INT32,
paddle::GPUPlace());
int *score_idx_ptr = score_idx.data<int32_t>();
DecodeLauncher(score_per_task.stream(), score_ptr, reg_ptr, height_ptr,
exp_dim_per_task_ptr, vel_ptr, rot_ptr, score_threshold,
feat_w, down_ratio, voxel_size_x, voxel_size_y,
point_cloud_range_x_min, point_cloud_range_y_min,
post_center_range_x_min, post_center_range_y_min,
post_center_range_z_min, post_center_range_x_max,
post_center_range_y_max, post_center_range_z_max, num_bboxes,
with_velocity, decode_bboxes_dims, decode_bboxes_ptr,
thresh_mask_ptr, score_idx_ptr);
// select score by mask
auto selected_score_idx =
paddle::experimental::masked_select(score_idx, thresh_mask);
auto flattened_selected_score =
paddle::experimental::reshape(score_per_task, {num_bboxes});
auto selected_score = paddle::experimental::masked_select(
flattened_selected_score, thresh_mask);
int num_selected = selected_score.numel();
if (num_selected == 0 || num_selected < 0) {
auto fake_out_boxes =
paddle::full({1, decode_bboxes_dims}, 0., paddle::DataType::FLOAT32,
paddle::GPUPlace());
auto fake_out_score =
paddle::full({1}, -1., paddle::DataType::FLOAT32, paddle::GPUPlace());
auto fake_out_label =
paddle::full({1}, 0, paddle::DataType::INT64, paddle::GPUPlace());
scores.push_back(fake_out_score);
labels.push_back(fake_out_label);
bboxes.push_back(fake_out_boxes);
continue;
}
// sort score by descending
auto sort_out = paddle::experimental::argsort(selected_score, 0, true);
auto sorted_index = std::get<1>(sort_out);
int num_bboxes_for_nms =
num_selected > nms_pre_max_size ? nms_pre_max_size : num_selected;
// nms
// in NmsLauncher, rot = - theta - pi / 2
const int col_blocks = DIVUP(num_bboxes_for_nms, THREADS_PER_BLOCK_NMS);
auto nms_mask = paddle::empty({num_bboxes_for_nms * col_blocks},
paddle::DataType::INT64, paddle::GPUPlace());
int64_t *nms_mask_data = nms_mask.data<int64_t>();
NmsLauncher(score_per_task.stream(), decode_bboxes.data<float>(),
selected_score_idx.data<int>(), sorted_index.data<int64_t>(),
num_selected, num_bboxes_for_nms, nms_iou_threshold,
decode_bboxes_dims, nms_mask_data);
const paddle::Tensor nms_mask_cpu_tensor =
nms_mask.copy_to(paddle::CPUPlace(), true);
const int64_t *nms_mask_cpu = nms_mask_cpu_tensor.data<int64_t>();
auto remv_cpu = paddle::full({col_blocks}, 0, paddle::DataType::INT64,
paddle::CPUPlace());
int64_t *remv_cpu_data = remv_cpu.data<int64_t>();
int num_to_keep = 0;
auto keep = paddle::empty({num_bboxes_for_nms}, paddle::DataType::INT32,
paddle::CPUPlace());
int *keep_data = keep.data<int>();
for (int i = 0; i < num_bboxes_for_nms; i++) {
int nblock = i / THREADS_PER_BLOCK_NMS;
int inblock = i % THREADS_PER_BLOCK_NMS;
if (!(remv_cpu_data[nblock] & (1ULL << inblock))) {
keep_data[num_to_keep++] = i;
const int64_t *p = &nms_mask_cpu[0] + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv_cpu_data[j] |= p[j];
}
}
}
int num_for_gather =
num_to_keep > nms_post_max_size ? nms_post_max_size : num_to_keep;
auto keep_gpu = paddle::empty({num_for_gather}, paddle::DataType::INT32,
paddle::GPUPlace());
int *keep_gpu_ptr = keep_gpu.data<int>();
cudaMemcpy(keep_gpu_ptr, keep_data, num_for_gather * sizeof(int),
cudaMemcpyHostToDevice);
auto gather_sorted_index =
paddle::experimental::gather(sorted_index, keep_gpu, 0);
auto gather_index = paddle::experimental::gather(selected_score_idx,
gather_sorted_index, 0);
auto gather_score =
paddle::experimental::gather(selected_score, gather_sorted_index, 0);
auto flattened_label =
paddle::experimental::reshape(label_per_task, {num_bboxes});
auto gather_label =
paddle::experimental::gather(flattened_label, gather_index, 0);
auto gather_bbox =
paddle::experimental::gather(decode_bboxes, gather_index, 0);
auto start_label = paddle::full(
{1}, num_classes[task_id], paddle::DataType::INT64, paddle::GPUPlace());
auto added_label = paddle::experimental::add(gather_label, start_label);
scores.push_back(gather_score);
labels.push_back(added_label);
bboxes.push_back(gather_bbox);
}
auto out_scores = paddle::experimental::concat(scores, 0);
auto out_labels = paddle::experimental::concat(labels, 0);
auto out_bboxes = paddle::experimental::concat(bboxes, 0);
return {out_bboxes, out_scores, out_labels};
}

View File

@@ -0,0 +1,309 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
3D IoU Calculation and Rotated NMS(modified from 2D NMS written by others)
Written by Shaoshuai Shi
All Rights Reserved 2019-2020.
*/
#include <stdio.h>
#define THREADS_PER_BLOCK 16
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
const int THREADS_PER_BLOCK_NMS = sizeof(int64_t) * 8;
const float EPS = 1e-8;
struct Point {
float x, y;
__device__ Point() {}
__device__ Point(double _x, double _y) { x = _x, y = _y; }
__device__ void set(float _x, float _y) {
x = _x;
y = _y;
}
__device__ Point operator+(const Point &b) const {
return Point(x + b.x, y + b.y);
}
__device__ Point operator-(const Point &b) const {
return Point(x - b.x, y - b.y);
}
};
__device__ inline float cross(const Point &a, const Point &b) {
return a.x * b.y - a.y * b.x;
}
__device__ inline float cross(const Point &p1, const Point &p2,
const Point &p0) {
return (p1.x - p0.x) * (p2.y - p0.y) - (p2.x - p0.x) * (p1.y - p0.y);
}
__device__ int check_rect_cross(const Point &p1, const Point &p2,
const Point &q1, const Point &q2) {
int ret = min(p1.x, p2.x) <= max(q1.x, q2.x) &&
min(q1.x, q2.x) <= max(p1.x, p2.x) &&
min(p1.y, p2.y) <= max(q1.y, q2.y) &&
min(q1.y, q2.y) <= max(p1.y, p2.y);
return ret;
}
__device__ inline int check_in_box2d(const float *box, const Point &p) {
// params: (7) [x, y, z, dx, dy, dz, heading]
const float MARGIN = 1e-2;
float center_x = box[0], center_y = box[1];
// rotate the point in the opposite direction of box
float angle_cos = cos(-box[6]), angle_sin = sin(-box[6]);
float rot_x = (p.x - center_x) * angle_cos + (p.y - center_y) * (-angle_sin);
float rot_y = (p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos;
return (fabs(rot_x) < box[3] / 2 + MARGIN &&
fabs(rot_y) < box[4] / 2 + MARGIN);
}
__device__ inline int intersection(const Point &p1, const Point &p0,
const Point &q1, const Point &q0,
Point *ans) {
// fast exclusion
if (check_rect_cross(p0, p1, q0, q1) == 0) return 0;
// check cross standing
float s1 = cross(q0, p1, p0);
float s2 = cross(p1, q1, p0);
float s3 = cross(p0, q1, q0);
float s4 = cross(q1, p1, q0);
if (!(s1 * s2 > 0 && s3 * s4 > 0)) return 0;
// calculate intersection of two lines
float s5 = cross(q1, p1, p0);
if (fabs(s5 - s1) > EPS) {
ans->x = (s5 * q0.x - s1 * q1.x) / (s5 - s1);
ans->y = (s5 * q0.y - s1 * q1.y) / (s5 - s1);
} else {
float a0 = p0.y - p1.y, b0 = p1.x - p0.x, c0 = p0.x * p1.y - p1.x * p0.y;
float a1 = q0.y - q1.y, b1 = q1.x - q0.x, c1 = q0.x * q1.y - q1.x * q0.y;
float D = a0 * b1 - a1 * b0;
ans->x = (b0 * c1 - b1 * c0) / D;
ans->y = (a1 * c0 - a0 * c1) / D;
}
return 1;
}
__device__ inline void rotate_around_center(const Point &center,
const float angle_cos,
const float angle_sin, Point *p) {
float new_x = (p->x - center.x) * angle_cos +
(p->y - center.y) * (-angle_sin) + center.x;
float new_y =
(p->x - center.x) * angle_sin + (p->y - center.y) * angle_cos + center.y;
p->set(new_x, new_y);
}
__device__ inline int point_cmp(const Point &a, const Point &b,
const Point &center) {
return atan2(a.y - center.y, a.x - center.x) >
atan2(b.y - center.y, b.x - center.x);
}
__device__ inline float box_overlap(const float *box_a, const float *box_b) {
// params box_a: [x, y, z, dx, dy, dz, heading]
// params box_b: [x, y, z, dx, dy, dz, heading]
float a_angle = box_a[6], b_angle = box_b[6];
float a_dx_half = box_a[3] / 2, b_dx_half = box_b[3] / 2,
a_dy_half = box_a[4] / 2, b_dy_half = box_b[4] / 2;
float a_x1 = box_a[0] - a_dx_half, a_y1 = box_a[1] - a_dy_half;
float a_x2 = box_a[0] + a_dx_half, a_y2 = box_a[1] + a_dy_half;
float b_x1 = box_b[0] - b_dx_half, b_y1 = box_b[1] - b_dy_half;
float b_x2 = box_b[0] + b_dx_half, b_y2 = box_b[1] + b_dy_half;
Point center_a(box_a[0], box_a[1]);
Point center_b(box_b[0], box_b[1]);
Point box_a_corners[5];
box_a_corners[0].set(a_x1, a_y1);
box_a_corners[1].set(a_x2, a_y1);
box_a_corners[2].set(a_x2, a_y2);
box_a_corners[3].set(a_x1, a_y2);
Point box_b_corners[5];
box_b_corners[0].set(b_x1, b_y1);
box_b_corners[1].set(b_x2, b_y1);
box_b_corners[2].set(b_x2, b_y2);
box_b_corners[3].set(b_x1, b_y2);
// get oriented corners
float a_angle_cos = cos(a_angle), a_angle_sin = sin(a_angle);
float b_angle_cos = cos(b_angle), b_angle_sin = sin(b_angle);
for (int k = 0; k < 4; k++) {
rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners + k);
rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners + k);
}
box_a_corners[4] = box_a_corners[0];
box_b_corners[4] = box_b_corners[0];
// get intersection of lines
Point cross_points[16];
Point poly_center;
int cnt = 0, flag = 0;
poly_center.set(0, 0);
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
flag = intersection(box_a_corners[i + 1], box_a_corners[i],
box_b_corners[j + 1], box_b_corners[j],
cross_points + cnt);
if (flag) {
poly_center = poly_center + cross_points[cnt];
cnt++;
}
}
}
// check corners
for (int k = 0; k < 4; k++) {
if (check_in_box2d(box_a, box_b_corners[k])) {
poly_center = poly_center + box_b_corners[k];
cross_points[cnt] = box_b_corners[k];
cnt++;
}
if (check_in_box2d(box_b, box_a_corners[k])) {
poly_center = poly_center + box_a_corners[k];
cross_points[cnt] = box_a_corners[k];
cnt++;
}
}
poly_center.x /= cnt;
poly_center.y /= cnt;
// sort the points of polygon
Point temp;
for (int j = 0; j < cnt - 1; j++) {
for (int i = 0; i < cnt - j - 1; i++) {
if (point_cmp(cross_points[i], cross_points[i + 1], poly_center)) {
temp = cross_points[i];
cross_points[i] = cross_points[i + 1];
cross_points[i + 1] = temp;
}
}
}
// get the overlap areas
float area = 0;
for (int k = 0; k < cnt - 1; k++) {
area += cross(cross_points[k] - cross_points[0],
cross_points[k + 1] - cross_points[0]);
}
return fabs(area) / 2.0;
}
__device__ inline float iou_bev(const float *box_a, const float *box_b) {
// params box_a: [x, y, z, dx, dy, dz, heading]
// params box_b: [x, y, z, dx, dy, dz, heading]
float sa = box_a[3] * box_a[4];
float sb = box_b[3] * box_b[4];
float s_overlap = box_overlap(box_a, box_b);
return s_overlap / fmaxf(sa + sb - s_overlap, EPS);
}
__global__ void nms_kernel(const int num_bboxes, const int num_bboxes_for_nms,
const float nms_overlap_thresh,
const int decode_bboxes_dims, const float *bboxes,
const int *index, const int64_t *sorted_index,
int64_t *mask) {
// params: boxes (N, 7) [x, y, z, dx, dy, dz, heading]
// params: mask (N, N/THREADS_PER_BLOCK_NMS)
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
// if (row_start > col_start) return;
const int row_size =
fminf(num_bboxes_for_nms - row_start * THREADS_PER_BLOCK_NMS,
THREADS_PER_BLOCK_NMS);
const int col_size =
fminf(num_bboxes_for_nms - col_start * THREADS_PER_BLOCK_NMS,
THREADS_PER_BLOCK_NMS);
__shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 7];
if (threadIdx.x < col_size) {
int box_idx =
index[sorted_index[THREADS_PER_BLOCK_NMS * col_start + threadIdx.x]];
block_boxes[threadIdx.x * 7 + 0] = bboxes[box_idx * decode_bboxes_dims];
block_boxes[threadIdx.x * 7 + 1] = bboxes[box_idx * decode_bboxes_dims + 1];
block_boxes[threadIdx.x * 7 + 2] = bboxes[box_idx * decode_bboxes_dims + 2];
block_boxes[threadIdx.x * 7 + 3] = bboxes[box_idx * decode_bboxes_dims + 4];
block_boxes[threadIdx.x * 7 + 4] = bboxes[box_idx * decode_bboxes_dims + 3];
block_boxes[threadIdx.x * 7 + 5] = bboxes[box_idx * decode_bboxes_dims + 5];
block_boxes[threadIdx.x * 7 + 6] =
-bboxes[box_idx * decode_bboxes_dims + decode_bboxes_dims - 1] -
3.141592653589793 / 2;
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
const int act_box_idx = index[sorted_index[cur_box_idx]];
float cur_box[7];
cur_box[0] = bboxes[act_box_idx * decode_bboxes_dims];
cur_box[1] = bboxes[act_box_idx * decode_bboxes_dims + 1];
cur_box[2] = bboxes[act_box_idx * decode_bboxes_dims + 2];
cur_box[3] = bboxes[act_box_idx * decode_bboxes_dims + 4];
cur_box[4] = bboxes[act_box_idx * decode_bboxes_dims + 3];
cur_box[5] = bboxes[act_box_idx * decode_bboxes_dims + 5];
cur_box[6] =
-bboxes[act_box_idx * decode_bboxes_dims + decode_bboxes_dims - 1] -
3.141592653589793 / 2;
int i = 0;
int64_t t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
if (iou_bev(cur_box, block_boxes + i * 7) > nms_overlap_thresh) {
t |= 1ULL << i;
}
}
const int col_blocks = DIVUP(num_bboxes_for_nms, THREADS_PER_BLOCK_NMS);
mask[cur_box_idx * col_blocks + col_start] = t;
}
}
void NmsLauncher(const cudaStream_t &stream, const float *bboxes,
const int *index, const int64_t *sorted_index,
const int num_bboxes, const int num_bboxes_for_nms,
const float nms_overlap_thresh, const int decode_bboxes_dims,
int64_t *mask) {
dim3 blocks(DIVUP(num_bboxes_for_nms, THREADS_PER_BLOCK_NMS),
DIVUP(num_bboxes_for_nms, THREADS_PER_BLOCK_NMS));
dim3 threads(THREADS_PER_BLOCK_NMS);
nms_kernel<<<blocks, threads, 0, stream>>>(
num_bboxes, num_bboxes_for_nms, nms_overlap_thresh, decode_bboxes_dims,
bboxes, index, sorted_index, mask);
}

View File

@@ -0,0 +1,201 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#if defined(WITH_GPU)
#if defined(PADDLEINFERENCE_API_COMPAT_2_4_x)
#include "paddle/include/experimental/ext_all.h"
#elif defined(PADDLEINFERENCE_API_COMPAT_2_5_x)
#include "paddle/include/paddle/extension.h"
#else
#include "paddle/extension.h"
#endif
template <typename T, typename T_int>
bool hard_voxelize_cpu_kernel(
const T *points, const float point_cloud_range_x_min,
const float point_cloud_range_y_min, const float point_cloud_range_z_min,
const float voxel_size_x, const float voxel_size_y,
const float voxel_size_z, const int grid_size_x, const int grid_size_y,
const int grid_size_z, const int64_t num_points, const int num_point_dim,
const int max_num_points_in_voxel, const int max_voxels, T *voxels,
T_int *coords, T_int *num_points_per_voxel, T_int *grid_idx_to_voxel_idx,
T_int *num_voxels) {
std::fill(voxels,
voxels + max_voxels * max_num_points_in_voxel * num_point_dim,
static_cast<T>(0));
num_voxels[0] = 0;
int voxel_idx, grid_idx, curr_num_point;
int coord_x, coord_y, coord_z;
for (int point_idx = 0; point_idx < num_points; ++point_idx) {
coord_x = floor(
(points[point_idx * num_point_dim + 0] - point_cloud_range_x_min) /
voxel_size_x);
coord_y = floor(
(points[point_idx * num_point_dim + 1] - point_cloud_range_y_min) /
voxel_size_y);
coord_z = floor(
(points[point_idx * num_point_dim + 2] - point_cloud_range_z_min) /
voxel_size_z);
if (coord_x < 0 || coord_x > grid_size_x || coord_x == grid_size_x) {
continue;
}
if (coord_y < 0 || coord_y > grid_size_y || coord_y == grid_size_y) {
continue;
}
if (coord_z < 0 || coord_z > grid_size_z || coord_z == grid_size_z) {
continue;
}
grid_idx =
coord_z * grid_size_y * grid_size_x + coord_y * grid_size_x + coord_x;
voxel_idx = grid_idx_to_voxel_idx[grid_idx];
if (voxel_idx == -1) {
voxel_idx = num_voxels[0];
if (num_voxels[0] == max_voxels || num_voxels[0] > max_voxels) {
continue;
}
num_voxels[0]++;
grid_idx_to_voxel_idx[grid_idx] = voxel_idx;
coords[voxel_idx * 3 + 0] = coord_z;
coords[voxel_idx * 3 + 1] = coord_y;
coords[voxel_idx * 3 + 2] = coord_x;
}
curr_num_point = num_points_per_voxel[voxel_idx];
if (curr_num_point < max_num_points_in_voxel) {
for (int j = 0; j < num_point_dim; ++j) {
voxels[voxel_idx * max_num_points_in_voxel * num_point_dim +
curr_num_point * num_point_dim + j] =
points[point_idx * num_point_dim + j];
}
num_points_per_voxel[voxel_idx] = curr_num_point + 1;
}
}
return true;
}
std::vector<paddle::Tensor> hard_voxelize_cpu(
const paddle::Tensor &points, const std::vector<float> &voxel_size,
const std::vector<float> &point_cloud_range,
const int max_num_points_in_voxel, const int max_voxels) {
auto num_points = points.shape()[0];
auto num_point_dim = points.shape()[1];
const float voxel_size_x = voxel_size[0];
const float voxel_size_y = voxel_size[1];
const float voxel_size_z = voxel_size[2];
const float point_cloud_range_x_min = point_cloud_range[0];
const float point_cloud_range_y_min = point_cloud_range[1];
const float point_cloud_range_z_min = point_cloud_range[2];
int grid_size_x = static_cast<int>(
round((point_cloud_range[3] - point_cloud_range[0]) / voxel_size_x));
int grid_size_y = static_cast<int>(
round((point_cloud_range[4] - point_cloud_range[1]) / voxel_size_y));
int grid_size_z = static_cast<int>(
round((point_cloud_range[5] - point_cloud_range[2]) / voxel_size_z));
auto voxels =
paddle::empty({max_voxels, max_num_points_in_voxel, num_point_dim},
paddle::DataType::FLOAT32, paddle::CPUPlace());
auto coords = paddle::full({max_voxels, 3}, 0, paddle::DataType::INT32,
paddle::CPUPlace());
auto *coords_data = coords.data<int>();
auto num_points_per_voxel = paddle::full(
{max_voxels}, 0, paddle::DataType::INT32, paddle::CPUPlace());
auto *num_points_per_voxel_data = num_points_per_voxel.data<int>();
std::fill(num_points_per_voxel_data,
num_points_per_voxel_data + num_points_per_voxel.size(),
static_cast<int>(0));
auto num_voxels =
paddle::full({1}, 0, paddle::DataType::INT32, paddle::CPUPlace());
auto *num_voxels_data = num_voxels.data<int>();
auto grid_idx_to_voxel_idx =
paddle::full({grid_size_z, grid_size_y, grid_size_x}, -1,
paddle::DataType::INT32, paddle::CPUPlace());
auto *grid_idx_to_voxel_idx_data = grid_idx_to_voxel_idx.data<int>();
PD_DISPATCH_FLOATING_TYPES(
points.type(), "hard_voxelize_cpu_kernel", ([&] {
hard_voxelize_cpu_kernel<data_t, int>(
points.data<data_t>(), point_cloud_range_x_min,
point_cloud_range_y_min, point_cloud_range_z_min, voxel_size_x,
voxel_size_y, voxel_size_z, grid_size_x, grid_size_y, grid_size_z,
num_points, num_point_dim, max_num_points_in_voxel, max_voxels,
voxels.data<data_t>(), coords_data, num_points_per_voxel_data,
grid_idx_to_voxel_idx_data, num_voxels_data);
}));
return {voxels, coords, num_points_per_voxel, num_voxels};
}
#ifdef PADDLE_WITH_CUDA
std::vector<paddle::Tensor> hard_voxelize_cuda(
const paddle::Tensor &points, const std::vector<float> &voxel_size,
const std::vector<float> &point_cloud_range, int max_num_points_in_voxel,
int max_voxels);
#endif
std::vector<paddle::Tensor> hard_voxelize(
const paddle::Tensor &points, const std::vector<float> &voxel_size,
const std::vector<float> &point_cloud_range,
const int max_num_points_in_voxel, const int max_voxels) {
if (points.is_cpu()) {
return hard_voxelize_cpu(points, voxel_size, point_cloud_range,
max_num_points_in_voxel, max_voxels);
#ifdef PADDLE_WITH_CUDA
} else if (points.is_gpu() || points.is_gpu_pinned()) {
return hard_voxelize_cuda(points, voxel_size, point_cloud_range,
max_num_points_in_voxel, max_voxels);
#endif
} else {
PD_THROW(
"Unsupported device type for hard_voxelize "
"operator.");
}
}
std::vector<std::vector<int64_t>> HardInferShape(
std::vector<int64_t> points_shape, const std::vector<float> &voxel_size,
const std::vector<float> &point_cloud_range,
const int &max_num_points_in_voxel, const int &max_voxels) {
return {{max_voxels, max_num_points_in_voxel, points_shape[1]},
{max_voxels, 3},
{max_voxels},
{1}};
}
std::vector<paddle::DataType> HardInferDtype(paddle::DataType points_dtype) {
return {points_dtype, paddle::DataType::INT32, paddle::DataType::INT32,
paddle::DataType::INT32};
}
PD_BUILD_OP(hard_voxelize)
.Inputs({"POINTS"})
.Outputs({"VOXELS", "COORS", "NUM_POINTS_PER_VOXEL", "num_voxels"})
.SetKernelFn(PD_KERNEL(hard_voxelize))
.Attrs({"voxel_size: std::vector<float>",
"point_cloud_range: std::vector<float>",
"max_num_points_in_voxel: int", "max_voxels: int"})
.SetInferShapeFn(PD_INFER_SHAPE(HardInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(HardInferDtype));
#endif // WITH_GPU

View File

@@ -0,0 +1,351 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLEINFERENCE_API_COMPAT_2_4_x)
#include "paddle/include/experimental/ext_all.h"
#elif defined(PADDLEINFERENCE_API_COMPAT_2_5_x)
#include "paddle/include/paddle/extension.h"
#else
#include "paddle/extension.h"
#endif
#define CHECK_INPUT_CUDA(x) \
PD_CHECK(x.is_gpu() || x.is_gpu_pinned(), #x " must be a GPU Tensor.")
#define CUDA_KERNEL_LOOP(i, n) \
for (auto i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <typename T, typename T_int>
__global__ void init_num_point_grid(
const T *points, const float point_cloud_range_x_min,
const float point_cloud_range_y_min, const float point_cloud_range_z_min,
const float voxel_size_x, const float voxel_size_y,
const float voxel_size_z, const int grid_size_x, const int grid_size_y,
const int grid_size_z, const int64_t num_points, const int num_point_dim,
T_int *num_points_in_grid, int *points_valid) {
int64_t point_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (point_idx > num_points || point_idx == num_points) {
return;
}
int coord_x =
floor((points[point_idx * num_point_dim + 0] - point_cloud_range_x_min) /
voxel_size_x);
int coord_y =
floor((points[point_idx * num_point_dim + 1] - point_cloud_range_y_min) /
voxel_size_y);
int coord_z =
floor((points[point_idx * num_point_dim + 2] - point_cloud_range_z_min) /
voxel_size_z);
if (coord_x < 0 || coord_x > grid_size_x || coord_x == grid_size_x) {
return;
}
if (coord_y < 0 || coord_y > grid_size_y || coord_y == grid_size_y) {
return;
}
if (coord_z < 0 || coord_z > grid_size_z || coord_z == grid_size_z) {
return;
}
int grid_idx =
coord_z * grid_size_y * grid_size_x + coord_y * grid_size_x + coord_x;
num_points_in_grid[grid_idx] = 0;
points_valid[grid_idx] = num_points;
}
template <typename T, typename T_int>
__global__ void map_point_to_grid_kernel(
const T *points, const float point_cloud_range_x_min,
const float point_cloud_range_y_min, const float point_cloud_range_z_min,
const float voxel_size_x, const float voxel_size_y,
const float voxel_size_z, const int grid_size_x, const int grid_size_y,
const int grid_size_z, const int64_t num_points, const int num_point_dim,
const int max_num_points_in_voxel, T_int *points_to_grid_idx,
T_int *points_to_num_idx, T_int *num_points_in_grid, int *points_valid) {
int64_t point_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (point_idx > num_points || point_idx == num_points) {
return;
}
int coord_x =
floor((points[point_idx * num_point_dim + 0] - point_cloud_range_x_min) /
voxel_size_x);
int coord_y =
floor((points[point_idx * num_point_dim + 1] - point_cloud_range_y_min) /
voxel_size_y);
int coord_z =
floor((points[point_idx * num_point_dim + 2] - point_cloud_range_z_min) /
voxel_size_z);
if (coord_x < 0 || coord_x > grid_size_x || coord_x == grid_size_x) {
return;
}
if (coord_y < 0 || coord_y > grid_size_y || coord_y == grid_size_y) {
return;
}
if (coord_z < 0 || coord_z > grid_size_z || coord_z == grid_size_z) {
return;
}
int grid_idx =
coord_z * grid_size_y * grid_size_x + coord_y * grid_size_x + coord_x;
T_int num = atomicAdd(num_points_in_grid + grid_idx, 1);
if (num < max_num_points_in_voxel) {
points_to_num_idx[point_idx] = num;
points_to_grid_idx[point_idx] = grid_idx;
atomicMin(points_valid + grid_idx, static_cast<int>(point_idx));
}
}
template <typename T_int>
__global__ void update_points_flag(const int *points_valid,
const T_int *points_to_grid_idx,
const int num_points, int *points_flag) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < num_points; i += gridDim.x * blockDim.x) {
T_int grid_idx = points_to_grid_idx[i];
if (grid_idx >= 0) {
int id = points_valid[grid_idx];
if (id != num_points && id == i) {
points_flag[i] = 1;
}
}
}
}
template <typename T_int>
__global__ void get_voxel_idx_kernel(const int *points_flag,
const T_int *points_to_grid_idx,
const int *points_flag_prefix_sum,
const int num_points, const int max_voxels,
T_int *num_voxels,
T_int *grid_idx_to_voxel_idx) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
for (int i = tid; i < num_points; i += gridDim.x * blockDim.x) {
if (points_flag[i] == 1) {
T_int grid_idx = points_to_grid_idx[i];
int num = points_flag_prefix_sum[i];
if (num < max_voxels) {
grid_idx_to_voxel_idx[grid_idx] = num;
}
}
if (i == num_points - 1) {
int num = points_flag_prefix_sum[i] + points_flag[i];
if (num < max_voxels) {
num_voxels[0] = num;
} else {
num_voxels[0] = max_voxels;
}
}
}
}
template <typename T>
__global__ void init_voxels_kernel(const int64_t num, T *voxels) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx > num || idx == num) {
return;
}
voxels[idx] = static_cast<T>(0);
}
template <typename T, typename T_int>
__global__ void assign_voxels_kernel(
const T *points, const T_int *points_to_grid_idx,
const T_int *points_to_num_idx, const T_int *grid_idx_to_voxel_idx,
const int64_t num_points, const int num_point_dim,
const int max_num_points_in_voxel, T *voxels) {
int64_t point_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (point_idx > num_points || point_idx == num_points) {
return;
}
T_int grid_idx = points_to_grid_idx[point_idx];
T_int num_idx = points_to_num_idx[point_idx];
if (grid_idx > -1 && num_idx > -1) {
T_int voxel_idx = grid_idx_to_voxel_idx[grid_idx];
if (voxel_idx > -1) {
for (int64_t i = 0; i < num_point_dim; ++i) {
voxels[voxel_idx * max_num_points_in_voxel * num_point_dim +
num_idx * num_point_dim + i] =
points[point_idx * num_point_dim + i];
}
}
}
}
template <typename T, typename T_int>
__global__ void assign_coords_kernel(const T_int *grid_idx_to_voxel_idx,
const T_int *num_points_in_grid,
const int num_grids, const int grid_size_x,
const int grid_size_y,
const int grid_size_z,
const int max_num_points_in_voxel,
T *coords, T *num_points_per_voxel) {
int64_t grid_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (grid_idx > num_grids || grid_idx == num_grids) {
return;
}
T_int voxel_idx = grid_idx_to_voxel_idx[grid_idx];
if (voxel_idx > -1) {
T_int coord_z = grid_idx / grid_size_x / grid_size_y;
T_int coord_y =
(grid_idx - coord_z * grid_size_x * grid_size_y) / grid_size_x;
T_int coord_x =
grid_idx - coord_z * grid_size_x * grid_size_y - coord_y * grid_size_x;
coords[voxel_idx * 3 + 0] = coord_z;
coords[voxel_idx * 3 + 1] = coord_y;
coords[voxel_idx * 3 + 2] = coord_x;
num_points_per_voxel[voxel_idx] =
min(num_points_in_grid[grid_idx], max_num_points_in_voxel);
}
}
std::vector<paddle::Tensor> hard_voxelize_cuda(
const paddle::Tensor &points, const std::vector<float> &voxel_size,
const std::vector<float> &point_cloud_range, int max_num_points_in_voxel,
int max_voxels) {
// check device
CHECK_INPUT_CUDA(points);
int64_t num_points = points.shape()[0];
int64_t num_point_dim = points.shape()[1];
const float voxel_size_x = voxel_size[0];
const float voxel_size_y = voxel_size[1];
const float voxel_size_z = voxel_size[2];
const float point_cloud_range_x_min = point_cloud_range[0];
const float point_cloud_range_y_min = point_cloud_range[1];
const float point_cloud_range_z_min = point_cloud_range[2];
int grid_size_x = static_cast<int>(
round((point_cloud_range[3] - point_cloud_range[0]) / voxel_size_x));
int grid_size_y = static_cast<int>(
round((point_cloud_range[4] - point_cloud_range[1]) / voxel_size_y));
int grid_size_z = static_cast<int>(
round((point_cloud_range[5] - point_cloud_range[2]) / voxel_size_z));
int num_grids = grid_size_x * grid_size_y * grid_size_z;
auto voxels =
paddle::empty({max_voxels, max_num_points_in_voxel, num_point_dim},
paddle::DataType::FLOAT32, paddle::GPUPlace());
auto coords = paddle::full({max_voxels, 3}, 0, paddle::DataType::INT32,
paddle::GPUPlace());
auto *coords_data = coords.data<int>();
auto num_points_per_voxel = paddle::full(
{max_voxels}, 0, paddle::DataType::INT32, paddle::GPUPlace());
auto *num_points_per_voxel_data = num_points_per_voxel.data<int>();
auto points_to_grid_idx = paddle::full(
{num_points}, -1, paddle::DataType::INT32, paddle::GPUPlace());
auto *points_to_grid_idx_data = points_to_grid_idx.data<int>();
auto points_to_num_idx = paddle::full(
{num_points}, -1, paddle::DataType::INT32, paddle::GPUPlace());
auto *points_to_num_idx_data = points_to_num_idx.data<int>();
auto num_points_in_grid =
paddle::empty({grid_size_z, grid_size_y, grid_size_x},
paddle::DataType::INT32, paddle::GPUPlace());
auto *num_points_in_grid_data = num_points_in_grid.data<int>();
auto grid_idx_to_voxel_idx =
paddle::full({grid_size_z, grid_size_y, grid_size_x}, -1,
paddle::DataType::INT32, paddle::GPUPlace());
auto *grid_idx_to_voxel_idx_data = grid_idx_to_voxel_idx.data<int>();
auto num_voxels =
paddle::full({1}, 0, paddle::DataType::INT32, paddle::GPUPlace());
auto *num_voxels_data = num_voxels.data<int>();
auto points_valid =
paddle::empty({grid_size_z, grid_size_y, grid_size_x},
paddle::DataType::INT32, paddle::GPUPlace());
int *points_valid_data = points_valid.data<int>();
auto points_flag = paddle::full({num_points}, 0, paddle::DataType::INT32,
paddle::GPUPlace());
// 1. Find the grid index for each point, compute the
// number of points in each grid
int64_t threads = 512;
int64_t blocks = (num_points + threads - 1) / threads;
PD_DISPATCH_FLOATING_TYPES(
points.type(), "init_num_point_grid", ([&] {
init_num_point_grid<data_t, int>
<<<blocks, threads, 0, points.stream()>>>(
points.data<data_t>(), point_cloud_range_x_min,
point_cloud_range_y_min, point_cloud_range_z_min, voxel_size_x,
voxel_size_y, voxel_size_z, grid_size_x, grid_size_y,
grid_size_z, num_points, num_point_dim, num_points_in_grid_data,
points_valid_data);
}));
PD_DISPATCH_FLOATING_TYPES(
points.type(), "map_point_to_grid_kernel", ([&] {
map_point_to_grid_kernel<data_t, int>
<<<blocks, threads, 0, points.stream()>>>(
points.data<data_t>(), point_cloud_range_x_min,
point_cloud_range_y_min, point_cloud_range_z_min, voxel_size_x,
voxel_size_y, voxel_size_z, grid_size_x, grid_size_y,
grid_size_z, num_points, num_point_dim, max_num_points_in_voxel,
points_to_grid_idx_data, points_to_num_idx_data,
num_points_in_grid_data, points_valid_data);
}));
// 2. Find the number of non-zero voxels
int *points_flag_data = points_flag.data<int>();
threads = 512;
blocks = (num_points + threads - 1) / threads;
update_points_flag<int><<<blocks, threads, 0, points.stream()>>>(
points_valid_data, points_to_grid_idx_data, num_points, points_flag_data);
auto points_flag_prefix_sum =
paddle::experimental::cumsum(points_flag, 0, false, true, false);
int *points_flag_prefix_sum_data = points_flag_prefix_sum.data<int>();
get_voxel_idx_kernel<int><<<blocks, threads, 0, points.stream()>>>(
points_flag_data, points_to_grid_idx_data, points_flag_prefix_sum_data,
num_points, max_voxels, num_voxels_data, grid_idx_to_voxel_idx_data);
// 3. Store points to voxels coords and num_points_per_voxel
int64_t num = max_voxels * max_num_points_in_voxel * num_point_dim;
threads = 512;
blocks = (num + threads - 1) / threads;
PD_DISPATCH_FLOATING_TYPES(points.type(), "init_voxels_kernel", ([&] {
init_voxels_kernel<data_t>
<<<blocks, threads, 0, points.stream()>>>(
num, voxels.data<data_t>());
}));
threads = 512;
blocks = (num_points + threads - 1) / threads;
PD_DISPATCH_FLOATING_TYPES(
points.type(), "assign_voxels_kernel", ([&] {
assign_voxels_kernel<data_t, int>
<<<blocks, threads, 0, points.stream()>>>(
points.data<data_t>(), points_to_grid_idx_data,
points_to_num_idx_data, grid_idx_to_voxel_idx_data, num_points,
num_point_dim, max_num_points_in_voxel, voxels.data<data_t>());
}));
// 4. Store coords, num_points_per_voxel
blocks = (num_grids + threads - 1) / threads;
assign_coords_kernel<int><<<blocks, threads, 0, points.stream()>>>(
grid_idx_to_voxel_idx_data, num_points_in_grid_data, num_grids,
grid_size_x, grid_size_y, grid_size_z, max_num_points_in_voxel,
coords_data, num_points_per_voxel_data);
return {voxels, coords, num_points_per_voxel, num_voxels};
}

View File

@@ -103,6 +103,8 @@ struct PaddleBackendOption {
/// Collect shape for model while enable_trt is true
bool collect_trt_shape = false;
/// Collect shape for model by device (for some custom ops)
bool collect_trt_shape_by_device = false;
/// Cache input shape for mkldnn while the input data will change dynamiclly
int mkldnn_cache_size = -1;
/// initialize memory size(MB) for GPU

View File

@@ -45,8 +45,12 @@ void BindPaddleOption(pybind11::module& m) {
&PaddleBackendOption::enable_memory_optimize)
.def_readwrite("switch_ir_debug", &PaddleBackendOption::switch_ir_debug)
.def_readwrite("ipu_option", &PaddleBackendOption::ipu_option)
.def_readwrite("xpu_option", &PaddleBackendOption::xpu_option)
.def_readwrite("trt_option", &PaddleBackendOption::trt_option)
.def_readwrite("collect_trt_shape",
&PaddleBackendOption::collect_trt_shape)
.def_readwrite("collect_trt_shape_by_device",
&PaddleBackendOption::collect_trt_shape_by_device)
.def_readwrite("mkldnn_cache_size",
&PaddleBackendOption::mkldnn_cache_size)
.def_readwrite("gpu_mem_init_size",

View File

@@ -256,6 +256,12 @@ bool PaddleBackend::InitFromPaddle(const std::string& model,
} else {
analysis_config.SetModel(model, params);
}
if (option.collect_trt_shape_by_device) {
if (option.device == Device::GPU) {
analysis_config.EnableUseGpu(option.gpu_mem_init_size, option.device_id,
paddle_infer::PrecisionType::kFloat32);
}
}
analysis_config.CollectShapeRangeInfo(shape_range_info);
auto predictor_tmp = paddle_infer::CreatePredictor(analysis_config);
std::map<std::string, std::vector<int>> max_shape;

View File

@@ -23,5 +23,6 @@ void BindPerception(pybind11::module& m) {
auto perception_module =
m.def_submodule("perception", "3D object perception models.");
BindSmoke(perception_module);
BindPetr(perception_module);
}
} // namespace fastdeploy

View File

@@ -58,17 +58,12 @@ setup_configs["LIBRARY_NAME"] = PACKAGE_NAME
setup_configs["PY_LIBRARY_NAME"] = PACKAGE_NAME + "_main"
# Backend options
setup_configs["ENABLE_TVM_BACKEND"] = os.getenv("ENABLE_TVM_BACKEND", "OFF")
setup_configs["ENABLE_RKNPU2_BACKEND"] = os.getenv("ENABLE_RKNPU2_BACKEND",
"OFF")
setup_configs["ENABLE_SOPHGO_BACKEND"] = os.getenv("ENABLE_SOPHGO_BACKEND",
"OFF")
setup_configs["ENABLE_RKNPU2_BACKEND"] = os.getenv("ENABLE_RKNPU2_BACKEND", "OFF")
setup_configs["ENABLE_SOPHGO_BACKEND"] = os.getenv("ENABLE_SOPHGO_BACKEND", "OFF")
setup_configs["ENABLE_ORT_BACKEND"] = os.getenv("ENABLE_ORT_BACKEND", "OFF")
setup_configs["ENABLE_OPENVINO_BACKEND"] = os.getenv("ENABLE_OPENVINO_BACKEND",
"OFF")
setup_configs["ENABLE_PADDLE_BACKEND"] = os.getenv("ENABLE_PADDLE_BACKEND",
"OFF")
setup_configs["ENABLE_POROS_BACKEND"] = os.getenv("ENABLE_POROS_BACKEND",
"OFF")
setup_configs["ENABLE_OPENVINO_BACKEND"] = os.getenv("ENABLE_OPENVINO_BACKEND", "OFF")
setup_configs["ENABLE_PADDLE_BACKEND"] = os.getenv("ENABLE_PADDLE_BACKEND", "OFF")
setup_configs["ENABLE_POROS_BACKEND"] = os.getenv("ENABLE_POROS_BACKEND", "OFF")
setup_configs["ENABLE_TRT_BACKEND"] = os.getenv("ENABLE_TRT_BACKEND", "OFF")
setup_configs["ENABLE_LITE_BACKEND"] = os.getenv("ENABLE_LITE_BACKEND", "OFF")
setup_configs["ENABLE_VISION"] = os.getenv("ENABLE_VISION", "OFF")
@@ -88,15 +83,16 @@ setup_configs["WITH_KUNLUNXIN"] = os.getenv("WITH_KUNLUNXIN", "OFF")
setup_configs["RKNN2_TARGET_SOC"] = os.getenv("RKNN2_TARGET_SOC", "")
# Custom deps settings
setup_configs["TRT_DIRECTORY"] = os.getenv("TRT_DIRECTORY", "UNDEFINED")
setup_configs["CUDA_DIRECTORY"] = os.getenv("CUDA_DIRECTORY",
"/usr/local/cuda")
setup_configs["CUDA_DIRECTORY"] = os.getenv("CUDA_DIRECTORY", "/usr/local/cuda")
setup_configs["OPENCV_DIRECTORY"] = os.getenv("OPENCV_DIRECTORY", "")
setup_configs["ORT_DIRECTORY"] = os.getenv("ORT_DIRECTORY", "")
setup_configs["PADDLEINFERENCE_DIRECTORY"] = os.getenv(
"PADDLEINFERENCE_DIRECTORY", "")
setup_configs["PADDLEINFERENCE_VERSION"] = os.getenv("PADDLEINFERENCE_VERSION",
"")
setup_configs["PADDLEINFERENCE_DIRECTORY"] = os.getenv("PADDLEINFERENCE_DIRECTORY", "")
setup_configs["PADDLEINFERENCE_VERSION"] = os.getenv("PADDLEINFERENCE_VERSION", "")
setup_configs["PADDLEINFERENCE_URL"] = os.getenv("PADDLEINFERENCE_URL", "")
setup_configs["PADDLEINFERENCE_API_COMPAT_2_4_x"] = os.getenv("PADDLEINFERENCE_API_COMPAT_2_4_x", "OFF")
setup_configs["PADDLEINFERENCE_API_COMPAT_2_5_x"] = os.getenv("PADDLEINFERENCE_API_COMPAT_2_5_x", "OFF")
setup_configs["PADDLEINFERENCE_API_COMPAT_DEV"] = os.getenv("PADDLEINFERENCE_API_COMPAT_DEV", "OFF")
setup_configs["PADDLEINFERENCE_API_CUSTOM_OP"] = os.getenv("PADDLEINFERENCE_API_CUSTOM_OP", "OFF")
setup_configs["PADDLE2ONNX_URL"] = os.getenv("PADDLE2ONNX_URL", "")
setup_configs["PADDLELITE_URL"] = os.getenv("PADDLELITE_URL", "")
# Other settings

View File

@@ -5,15 +5,20 @@ set +x
# -------------------------------------------------------------------------------
# readonly global variables
# -------------------------------------------------------------------------------
readonly ROOT_PATH=$(pwd)
readonly BUILD_ROOT=build/Linux
readonly BUILD_DIR="${BUILD_ROOT}/x86_64_gpu"
readonly PADDLEINFERENCE_DIRECTORY=$1
readonly PADDLEINFERENCE_VERSION=$2
ROOT_PATH=$(pwd)
BUILD_ROOT=build/Linux
BUILD_DIR="${BUILD_ROOT}/x86_64_gpu"
PADDLEINFERENCE_DIRECTORY=$1
PADDLEINFERENCE_VERSION=$2
PADDLEINFERENCE_API_CUSTOM_OP=$3
BUILD_WITH_CUSTOM_PADDLE='OFF'
if [[ "$PADDLEINFERENCE_DIRECTORY" != "" ]]; then
if [[ -d "$1" ]]; then
BUILD_WITH_CUSTOM_PADDLE='ON'
else
if [[ "$1" == "ON" ]]; then
PADDLEINFERENCE_API_CUSTOM_OP='ON'
fi
fi
# -------------------------------------------------------------------------------
@@ -71,6 +76,7 @@ __build_fastdeploy_linux_x86_64_gpu_shared() {
-DENABLE_VISION=ON \
-DENABLE_BENCHMARK=ON \
-DBUILD_EXAMPLES=OFF \
-DPADDLEINFERENCE_API_CUSTOM_OP=${PADDLEINFERENCE_API_CUSTOM_OP:-"OFF"} \
-DCMAKE_INSTALL_PREFIX=${FASDEPLOY_INSTALL_DIR} \
-Wno-dev ../../.. && make -j8 && make install
@@ -91,6 +97,7 @@ __build_fastdeploy_linux_x86_64_gpu_shared_custom_paddle() {
-DENABLE_PADDLE_BACKEND=ON \
-DPADDLEINFERENCE_DIRECTORY=${PADDLEINFERENCE_DIRECTORY} \
-DPADDLEINFERENCE_VERSION=${PADDLEINFERENCE_VERSION} \
-DPADDLEINFERENCE_API_CUSTOM_OP=${PADDLEINFERENCE_API_CUSTOM_OP:-"OFF"} \
-DENABLE_OPENVINO_BACKEND=ON \
-DENABLE_PADDLE2ONNX=ON \
-DENABLE_VISION=ON \
@@ -118,5 +125,5 @@ main() {
main
# Usage:
# ./scripts/linux/build_linux_x86_64_cpp_gpu.sh
# ./scripts/linux/build_linux_x86_64_cpp_gpu.sh paddle_inference-linux-x64-gpu-trt8.5.2.2-mkl-avx-2.4.2 paddle2.4.2
# ./scripts/linux/build_linux_x86_64_cpp_gpu_with_benchmark.sh
# ./scripts/linux/build_linux_x86_64_cpp_gpu_with_benchmark.sh $PADDLEINFERENCE_DIRECTORY $PADDLEINFERENCE_VERSION