[Sync][Internal] sync some internal paddle3d codes (#2108)

DefTruth committed via GitHub, 2023-07-13 22:06:28 +08:00
parent 77ee48f9b8
commit 681ccc4c24
30 changed files with 2517 additions and 45 deletions

View File

@@ -462,7 +462,7 @@ endif()
 if(ENABLE_ENCRYPTION)
   add_definitions(-DENABLE_ENCRYPTION)
   list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_ENCRYPTION_SRCS})
-  include(${PROJECT_SOURCE_DIR}/cmake/gflags.cmake)
+  # include(${PROJECT_SOURCE_DIR}/cmake/gflags.cmake)
   include(${PROJECT_SOURCE_DIR}/cmake/openssl.cmake)
   list(APPEND DEPEND_LIBS ${OPENSSL_LIBRARIES})
 endif()

View File

@@ -21,6 +21,7 @@ add_executable(benchmark_structurev2_table ${PROJECT_SOURCE_DIR}/benchmark_struc
 add_executable(benchmark_structurev2_layout ${PROJECT_SOURCE_DIR}/benchmark_structurev2_layout.cc)
 add_executable(benchmark_ppshituv2_rec ${PROJECT_SOURCE_DIR}/benchmark_ppshituv2_rec.cc)
 add_executable(benchmark_ppshituv2_det ${PROJECT_SOURCE_DIR}/benchmark_ppshituv2_det.cc)
+add_executable(benchmark_pp3d_cadnn ${PROJECT_SOURCE_DIR}/benchmark_pp3d_cadnn.cc)
 add_executable(benchmark_pp3d_centerpoint ${PROJECT_SOURCE_DIR}/benchmark_pp3d_centerpoint.cc)
 if(UNIX AND (NOT APPLE) AND (NOT ANDROID))
@@ -34,6 +35,7 @@ if(UNIX AND (NOT APPLE) AND (NOT ANDROID))
   target_link_libraries(benchmark_structurev2_layout ${FASTDEPLOY_LIBS} gflags pthread)
   target_link_libraries(benchmark_ppshituv2_rec ${FASTDEPLOY_LIBS} gflags pthread)
   target_link_libraries(benchmark_ppshituv2_det ${FASTDEPLOY_LIBS} gflags pthread)
+  target_link_libraries(benchmark_pp3d_cadnn ${FASTDEPLOY_LIBS} gflags pthread)
   target_link_libraries(benchmark_pp3d_centerpoint ${FASTDEPLOY_LIBS} gflags pthread)
 else()
   target_link_libraries(benchmark ${FASTDEPLOY_LIBS} gflags)
@@ -46,6 +48,7 @@ else()
   target_link_libraries(benchmark_structurev2_layout ${FASTDEPLOY_LIBS} gflags)
   target_link_libraries(benchmark_ppshituv2_rec ${FASTDEPLOY_LIBS} gflags)
   target_link_libraries(benchmark_ppshituv2_det ${FASTDEPLOY_LIBS} gflags)
+  target_link_libraries(benchmark_pp3d_cadnn ${FASTDEPLOY_LIBS} gflags)
   target_link_libraries(benchmark_pp3d_centerpoint ${FASTDEPLOY_LIBS} gflags)
 endif()
 # only for Android ADB test

View File

@@ -41,9 +41,17 @@ fi
 # PP-ShiTuV2
 ./benchmark_ppshituv2_rec --model PP-ShiTuv2-rec --image ppshituv2_wangzai.png --config_path $CONFIG_PATH
+./benchmark_ppshituv2_det --model PP-ShiTuv2-det --image ppdet_det_img.jpg --config_path $CONFIG_PATH
 # PP-StructureV2
 ./benchmark_structurev2_layout --model PP-Structurev2-layout --image structurev2_layout_val_0002.jpg --config_path $CONFIG_PATH
 ./benchmark_structurev2_table --model PP-Structurev2-SLANet --image structurev2_table.jpg --table_char_dict_path table_structure_dict_ch.txt --config_path $CONFIG_PATH
+./benchmark --model PP-Structurev2-vi-layoutxlm --shapes "1,512:1,512,4:1,512:1,512" --trt_shapes "1,512:1,512:1,512:1,512,4:1,512,4:1,512,4:1,512:1,512:1,512:1,512:1,512:1,512" --names "x_0:x_1:x_2:x_3" --dtypes "INT64:INT64:INT64:INT64" --disable_mkldnn --custom_tensor_value 0.2 --config_path $CONFIG_PATH
+# Paddle3D
+./benchmark --model PETRv1_v99 --shapes "1,6,3,320,800:1,6,4,4" --names "images:img2lidars" --dtypes "FP32:FP32" --disable_mkldnn --config_path $CONFIG_PATH
+./benchmark --model PETRv2_v99 --shapes "1,12,3,320,800:1,12,4,4:1,12" --names "images:img2lidars:timestamps" --dtypes "FP32:FP32:FP32" --disable_mkldnn --config_path $CONFIG_PATH
+./benchmark_pp3d_centerpoint --model CenterPoint-Pillars-02Voxel --image paddle3d_centerpoint_n008_LIDAR_TOP__1533151603547590.pcd.bin --config_path $CONFIG_PATH
+./benchmark_pp3d_cadnn --model CADNN_OCRNet-HRNetW18 --image paddle3d_cadnn_kitti_000780.png --config_path $CONFIG_PATH
 set +x
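
In the --shapes / --trt_shapes strings used above, each colon-separated field appears to be one input's shape written as comma-separated dims, and --trt_shapes seems to pack min/opt/max triples per input (inferred from the twelve fields covering four inputs). A hypothetical C++ parser for the basic form, written from this usage rather than from the benchmark sources:

// Hypothetical parser (names and semantics inferred from the flag usage
// above, not taken from the benchmark sources): "1,6,3,320,800:1,6,4,4"
// parses to {{1,6,3,320,800}, {1,6,4,4}}.
#include <cstdio>
#include <sstream>
#include <string>
#include <vector>

std::vector<std::vector<int>> ParseShapes(const std::string& s) {
  std::vector<std::vector<int>> shapes;
  std::stringstream per_input(s);
  std::string one;
  while (std::getline(per_input, one, ':')) {  // one field per input
    std::vector<int> dims;
    std::stringstream per_dim(one);
    std::string d;
    while (std::getline(per_dim, d, ',')) dims.push_back(std::stoi(d));
    shapes.push_back(dims);
  }
  return shapes;
}

int main() {
  for (const auto& shape : ParseShapes("1,6,3,320,800:1,6,4,4")) {
    for (int dim : shape) std::printf("%d ", dim);
    std::printf("\n");  // prints: 1 6 3 320 800 / 1 6 4 4
  }
  return 0;
}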

View File

@@ -43,9 +43,17 @@ fi
 # PP-ShiTuV2
 ./benchmark_ppshituv2_rec --model PP-ShiTuv2-rec --image ppshituv2_wangzai.png --config_path $CONFIG_PATH
+./benchmark_ppshituv2_det --model PP-ShiTuv2-det --image ppdet_det_img.jpg --config_path $CONFIG_PATH
 # PP-StructureV2
 ./benchmark_structurev2_layout --model PP-Structurev2-layout --image structurev2_layout_val_0002.jpg --config_path $CONFIG_PATH
 ./benchmark_structurev2_table --model PP-Structurev2-SLANet --image structurev2_table.jpg --table_char_dict_path table_structure_dict_ch.txt --config_path $CONFIG_PATH
+./benchmark --model PP-Structurev2-vi-layoutxlm --shapes "1,512:1,512,4:1,512:1,512" --trt_shapes "1,512:1,512:1,512:1,512,4:1,512,4:1,512,4:1,512:1,512:1,512:1,512:1,512:1,512" --names "x_0:x_1:x_2:x_3" --dtypes "INT64:INT64:INT64:INT64" --disable_mkldnn --custom_tensor_value 0.2 --collect_trt_shape_by_custom_tensor_value --collect_trt_shape_by_device --config_path $CONFIG_PATH
+# Paddle3D
+./benchmark --model PETRv1_v99 --shapes "1,6,3,320,800:1,6,4,4" --trt_shapes "1,6,3,320,800:1,6,3,320,800:1,6,3,320,800:1,6,4,4:1,6,4,4:1,6,4,4" --names "images:img2lidars" --dtypes "FP32:FP32" --disable_mkldnn --config_path $CONFIG_PATH
+./benchmark --model PETRv2_v99 --shapes "1,12,3,320,800:1,12,4,4:1,12" --trt_shapes "1,12,3,320,800:1,12,3,320,800:1,12,3,320,800:1,12,4,4:1,12,4,4:1,12,4,4:1,12:1,12:1,12" --names "images:img2lidars:timestamps" --dtypes "FP32:FP32:FP32" --disable_mkldnn --config_path $CONFIG_PATH
+./benchmark_pp3d_centerpoint --model CenterPoint-Pillars-02Voxel --image paddle3d_centerpoint_n008_LIDAR_TOP__1533151603547590.pcd.bin --config_path $CONFIG_PATH
+./benchmark_pp3d_cadnn --model CADNN_OCRNet-HRNetW18 --image paddle3d_cadnn_kitti_000780.png --config_path $CONFIG_PATH
 set +x

View File

@@ -0,0 +1,81 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "flags.h"
#include "macros.h"
#include "option.h"
namespace vision = fastdeploy::vision;
namespace benchmark = fastdeploy::benchmark;
int main(int argc, char* argv[]) {
#if defined(ENABLE_BENCHMARK) && defined(ENABLE_VISION)
// Initialization
auto option = fastdeploy::RuntimeOption();
if (!CreateRuntimeOption(&option, argc, argv, true)) {
return -1;
}
auto im = cv::imread(FLAGS_image);
std::unordered_map<std::string, std::string> config_info;
benchmark::ResultManager::LoadBenchmarkConfig(FLAGS_config_path,
&config_info);
std::string model_name, params_name, config_name;
auto model_format = fastdeploy::ModelFormat::PADDLE;
if (!UpdateModelResourceName(&model_name, &params_name, &config_name,
&model_format, config_info, false)) {
return -1;
}
auto model_file = FLAGS_model + sep + model_name;
auto params_file = FLAGS_model + sep + params_name;
std::vector<float> cam_data{7.183351e+02, 0.000000e+00, 6.003891e+02,
4.450382e+01, 0.000000e+00, 7.183351e+02,
1.815122e+02, -5.951107e-01, 0.000000e+00,
0.000000e+00, 1.000000e+00, 2.616315e-03};
std::vector<float> lidar_data = {
0.0048523, -0.9999298, -0.01081266, -0.00711321,
-0.00302069, 0.01079808, -0.99993706, -0.06176636,
0.99998367, 0.00488465, -0.00296808, -0.26739058,
0., 0., 0., 1.};
if (config_info["backend"] == "paddle_trt") {
option.paddle_infer_option.collect_trt_shape = true;
option.paddle_infer_option.collect_trt_shape_by_device = true;
option.paddle_infer_option.trt_min_subgraph_size = 12;
option.paddle_infer_option.DisableTrtOps({"squeeze2"});
option.trt_option.max_batch_size = 1;
}
if (config_info["backend"] == "paddle_trt" ||
config_info["backend"] == "trt") {
// use custom input data when collecting the dynamic shapes.
option.trt_option.SetShape("images", {1, 3, 375, 1242},
{1, 3, 375, 1242}, {1, 3, 375, 1242});
option.trt_option.SetShape("trans_lidar_to_cam", {1, 4, 4},
{1, 4, 4}, {1, 4, 4});
option.trt_option.SetShape("trans_cam_to_img", {1, 3, 4},
{1, 3, 4}, {1, 3, 4});
std::vector<float> image_data;
image_data.assign(im.data, im.data + 1*3*375*1242);
option.trt_option.SetInputData("trans_lidar_to_cam", lidar_data);
option.trt_option.SetInputData("trans_cam_to_img", cam_data);
option.trt_option.SetInputData("images", image_data);
}
auto model_cadnn = vision::perception::Caddn(
model_file, params_file, "", option, model_format);
vision::PerceptionResult res;
// Run profiling
BENCHMARK_MODEL(model_cadnn, model_cadnn.Predict(im, cam_data, lidar_data, &res))
std::cout << res.Str() << std::endl;
#endif
return 0;
}
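
The hard-coded cam_data and lidar_data above follow the KITTI calibration layout: a 3x4 camera projection matrix and a 4x4 lidar-to-camera transform, both flattened row-major. A standalone sketch of how such matrices carry a lidar point into pixel coordinates (the helper and the sample point are illustrative, not part of this commit):

// Illustrative only: project a homogeneous lidar point through a 4x4
// lidar-to-camera transform T and a 3x4 projection matrix P, both row-major
// as in cam_data / lidar_data above.
#include <array>
#include <cstdio>

int main() {
  std::array<float, 12> P = {7.183351e+02f, 0.f, 6.003891e+02f, 4.450382e+01f,
                             0.f, 7.183351e+02f, 1.815122e+02f, -5.951107e-01f,
                             0.f, 0.f, 1.f, 2.616315e-03f};
  std::array<float, 16> T = {
      0.0048523f,   -0.9999298f, -0.01081266f, -0.00711321f,
      -0.00302069f, 0.01079808f, -0.99993706f, -0.06176636f,
      0.99998367f,  0.00488465f, -0.00296808f, -0.26739058f,
      0.f,          0.f,         0.f,          1.f};
  float pt[4] = {10.f, 1.f, -1.f, 1.f};  // hypothetical lidar point (x,y,z,1)
  float cam[4] = {0.f, 0.f, 0.f, 0.f};   // the point in the camera frame
  for (int r = 0; r < 4; ++r)
    for (int c = 0; c < 4; ++c) cam[r] += T[r * 4 + c] * pt[c];
  float uvw[3] = {0.f, 0.f, 0.f};  // projected, before perspective divide
  for (int r = 0; r < 3; ++r)
    for (int c = 0; c < 4; ++c) uvw[r] += P[r * 4 + c] * cam[c];
  std::printf("u=%.1f v=%.1f depth=%.2f\n", uvw[0] / uvw[2], uvw[1] / uvw[2],
              uvw[2]);
  return 0;
}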

View File

@@ -93,7 +93,7 @@ int main(int argc, char* argv[]) {
 vision::PerceptionResult res;
 // Run profiling
 BENCHMARK_MODEL(model_centerpoint, model_centerpoint.Predict(point_dir, &res))
-// std::cout << res.Str() << std::endl;
+std::cout << res.Str() << std::endl;
 #endif
 return 0;

View File

@@ -41,9 +41,17 @@ fi
 # PP-ShiTuV2
 ./benchmark_ppshituv2_rec --model PP-ShiTuv2-rec --image ppshituv2_wangzai.png --config_path $CONFIG_PATH
+./benchmark_ppshituv2_det --model PP-ShiTuv2-det --image ppdet_det_img.jpg --config_path $CONFIG_PATH
 # PP-StructureV2
 ./benchmark_structurev2_layout --model PP-Structurev2-layout --image structurev2_layout_val_0002.jpg --config_path $CONFIG_PATH
 ./benchmark_structurev2_table --model PP-Structurev2-SLANet --image structurev2_table.jpg --table_char_dict_path table_structure_dict_ch.txt --config_path $CONFIG_PATH
+./benchmark --model PP-Structurev2-vi-layoutxlm --shapes "1,512:1,512,4:1,512:1,512" --names "x_0:x_1:x_2:x_3" --dtypes "INT64:INT64:INT64:INT64" --disable_mkldnn --custom_tensor_value 0.2 --config_path $CONFIG_PATH
+# Paddle3D
+./benchmark --model PETRv1_v99 --config_path $CONFIG_PATH --shapes "1,6,3,320,800:1,6,4,4" --names "images:img2lidars" --dtypes "FP32:FP32" --disable_mkldnn --warmup 5 --repeat 20
+./benchmark --model PETRv2_v99 --config_path $CONFIG_PATH --shapes "1,12,3,320,800:1,12,4,4:1,12" --names "images:img2lidars:timestamps" --dtypes "FP32:FP32:FP32" --disable_mkldnn --warmup 5 --repeat 20
+./benchmark_pp3d_centerpoint --model CenterPoint-Pillars-02Voxel --image paddle3d_centerpoint_n008_LIDAR_TOP__1533151603547590.pcd.bin --config_path $CONFIG_PATH
+./benchmark_pp3d_cadnn --model CADNN_OCRNet-HRNetW18 --image paddle3d_cadnn_kitti_000780.png --config_path $CONFIG_PATH
 set +x

View File

@@ -119,3 +119,4 @@ PD_BUILD_OP(centerpoint_postprocess)
 .SetInferDtypeFn(PD_INFER_DTYPE(fastdeploy::paddle_custom_ops::PostProcessInferDtype));
 #endif // WITH_GPU

View File

@@ -220,7 +220,7 @@ std::vector<paddle::Tensor> postprocess_gpu(
 // nms
 // in NmsLauncher, rot = - theta - pi / 2
-const int col_blocks = DIVUP(num_bboxes_for_nms, THREADS_PER_BLOCK_NMS);
+int col_blocks = DIVUP(num_bboxes_for_nms, THREADS_PER_BLOCK_NMS);
 auto nms_mask = paddle::empty({num_bboxes_for_nms * col_blocks},
                               paddle::DataType::INT64, paddle::GPUPlace());
 int64_t *nms_mask_data = nms_mask.data<int64_t>();

View File

@@ -0,0 +1,94 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(WITH_GPU)
#include "grid_sample_3d.h"
#include <vector>
#if defined(PADDLEINFERENCE_API_COMPAT_2_4_x)
#include "paddle/include/experimental/ext_all.h"
#elif defined(PADDLEINFERENCE_API_COMPAT_2_5_x)
#include "paddle/include/paddle/extension.h"
#else
#include "paddle/extension.h"
#endif
namespace fastdeploy {
namespace paddle_custom_ops {
std::vector<paddle::Tensor> GridSample3DCUDAForward(
const paddle::Tensor& x, const paddle::Tensor& grid,
const std::string& mode, const std::string& padding_mode,
bool align_corners);
std::vector<paddle::Tensor> GridSample3DForward(const paddle::Tensor& x,
const paddle::Tensor& grid,
const std::string& mode,
const std::string& padding_mode,
bool align_corners) {
return GridSample3DCUDAForward(x, grid, mode, padding_mode, align_corners);
}
std::vector<paddle::Tensor> GridSample3DCUDABackward(
const paddle::Tensor& x, const paddle::Tensor& grid,
const paddle::Tensor& grad_out, const std::string& mode,
const std::string& padding_mode, bool align_corners);
std::vector<paddle::Tensor> GridSample3DBackward(
const paddle::Tensor& x, const paddle::Tensor& grid,
const paddle::Tensor& grad_out, const std::string& mode,
const std::string& padding_mode, bool align_corners) {
return GridSample3DCUDABackward(x, grid, grad_out, mode, padding_mode,
align_corners);
}
std::vector<std::vector<int64_t>> GridSample3DInferShape(
std::vector<int64_t> x_shape, std::vector<int64_t> grid_shape) {
return {
{x_shape[0], x_shape[1], grid_shape[1], grid_shape[2], grid_shape[3]}};
}
std::vector<std::vector<int64_t>> GridSample3DInferBackShape(
std::vector<int64_t> x_shape, std::vector<int64_t> grid_shape) {
return {x_shape};
}
std::vector<paddle::DataType> GridSample3DInferDtype(
paddle::DataType x_dtype, paddle::DataType grid_dtype) {
return {x_dtype};
}
} // namespace paddle_custom_ops
} // namespace fastdeploy
PD_BUILD_OP(grid_sample_3d)
.Inputs({"x", "grid"})
.Attrs({"mode: std::string", "padding_mode: std::string",
"align_corners: bool"})
.Outputs({"out"})
.SetKernelFn(PD_KERNEL(fastdeploy::paddle_custom_ops::GridSample3DForward))
.SetInferShapeFn(PD_INFER_SHAPE(fastdeploy::paddle_custom_ops::GridSample3DInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(fastdeploy::paddle_custom_ops::GridSample3DInferDtype));
PD_BUILD_GRAD_OP(grid_sample_3d)
.Inputs({"x", "grid", paddle::Grad("out")})
.Attrs({"mode: std::string", "padding_mode: std::string",
"align_corners: bool"})
.Outputs({paddle::Grad("x")})
.SetKernelFn(PD_KERNEL(fastdeploy::paddle_custom_ops::GridSample3DBackward))
.SetInferShapeFn(PD_INFER_SHAPE(fastdeploy::paddle_custom_ops::GridSample3DInferBackShape));
#endif
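
The registration above follows the usual grid_sample contract: x is laid out NCDHW, grid is N x D_out x H_out x W_out x 3 holding normalized (x, y, z) coordinates in [-1, 1], and GridSample3DInferShape accordingly reports an output of {N, C, D_out, H_out, W_out}. A minimal CPU sketch of the nearest-neighbor path makes the indexing concrete (assumptions: float data, zeros padding, align_corners = true; the function is illustrative, not part of the commit):

// Illustrative CPU reference for nearest-mode 3D grid sampling under the
// shape contract above (zeros padding, align_corners = true).
#include <cmath>
#include <vector>

std::vector<float> GridSample3DNearestCPU(const std::vector<float>& x,
                                          const std::vector<float>& grid,
                                          int n, int c, int d, int h, int w,
                                          int od, int oh, int ow) {
  std::vector<float> out(static_cast<size_t>(n) * c * od * oh * ow, 0.f);
  auto unnorm = [](float coord, int size) {
    return (coord + 1.f) / 2.f * (size - 1);  // align_corners == true
  };
  for (int ni = 0; ni < n; ++ni)
    for (int dz = 0; dz < od; ++dz)
      for (int hy = 0; hy < oh; ++hy)
        for (int wx = 0; wx < ow; ++wx) {
          size_t g = ((((size_t)ni * od + dz) * oh + hy) * ow + wx) * 3;
          int ix = (int)std::round(unnorm(grid[g + 0], w));  // x -> width
          int iy = (int)std::round(unnorm(grid[g + 1], h));  // y -> height
          int iz = (int)std::round(unnorm(grid[g + 2], d));  // z -> depth
          if (ix < 0 || ix >= w || iy < 0 || iy >= h || iz < 0 || iz >= d)
            continue;  // zeros padding: output stays 0
          for (int ci = 0; ci < c; ++ci) {
            size_t src = ((((size_t)ni * c + ci) * d + iz) * h + iy) * w + ix;
            size_t dst =
                ((((size_t)ni * c + ci) * od + dz) * oh + hy) * ow + wx;
            out[dst] = x[src];
          }
        }
  return out;
}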

View File

@@ -0,0 +1,657 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda.h>
#include "grid_sample_3d.h"
#if defined(PADDLEINFERENCE_API_COMPAT_2_4_x)
#include "paddle/include/experimental/ext_all.h"
#elif defined(PADDLEINFERENCE_API_COMPAT_2_5_x)
#include "paddle/include/paddle/extension.h"
#else
#include "paddle/extension.h"
#endif
namespace fastdeploy {
namespace paddle_custom_ops {
#define CHECK_INPUT_GPU(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")
static __forceinline__ __device__ bool InBounds3D(int64_t d, int64_t h,
int64_t w, int64_t D,
int64_t H, int64_t W) {
return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
}
#define CUDA_KERNEL_LOOP_TYPE(i, n, index_type) \
index_type _i_n_d_e_x = blockIdx.x * blockDim.x + threadIdx.x; \
for (index_type i = _i_n_d_e_x; _i_n_d_e_x < (n); \
_i_n_d_e_x += blockDim.x * gridDim.x, i = _i_n_d_e_x)
#define CUDA_KERNEL_LOOP(i, n) CUDA_KERNEL_LOOP_TYPE(i, n, int)
template <typename T>
static __forceinline__ __device__ T Unnormalize(T coord, int size,
bool align_corners) {
if (align_corners) {
return ((coord + 1.f) / 2) * (size - 1);
} else {
return ((coord + 1.f) * size - 1) / 2;
}
}
template <typename T>
static __forceinline__ __device__ T ClipIndexes(T in, int max_value) {
return min(static_cast<T>(max_value), max(in, static_cast<T>(0)));
}
template <typename T>
static __forceinline__ __device__ T ReflectIndexes(T in, int twice_low,
int twice_high) {
if (twice_low == twice_high) {
return static_cast<T>(0);
}
T min = static_cast<T>(twice_low) / 2;
T span = static_cast<T>(twice_high - twice_low) / 2;
in = fabs(in - min);
T extra = fmod(in, span);
int flips = static_cast<int>(floor(in / span));
if (flips % 2 == 0) {
return extra + min;
} else {
return span - extra + min;
}
}
template <typename T>
static __forceinline__ __device__ T ComputePositions(T coord, int size,
PaddingMode padding_mode,
bool align_corners) {
coord = Unnormalize<T>(coord, size, align_corners);
if (padding_mode == PaddingMode::border) {
coord = ClipIndexes(coord, size - 1);
} else if (padding_mode == PaddingMode::reflect) {
if (align_corners) {
coord = ReflectIndexes(coord, 0, 2 * (size - 1));
} else {
coord = ReflectIndexes(coord, -1, 2 * size - 1);
}
coord = ClipIndexes(coord, size - 1);
}
return coord;
}
template <typename T, typename index_t>
__global__ void GridSample3DCudaKernel(
const index_t nthreads, index_t out_c, index_t out_d, index_t out_h,
index_t out_w, index_t in_d, index_t in_h, index_t in_w, const T* input,
const T* grid, T* output, const Mode interpolation_mode,
const PaddingMode padding_mode, bool align_corners) {
// printf("size: %d, %d, %d, %d, %d, %d \n", out_c, out_d, out_w, out_h, in_d,
// in_w);
index_t inp_sW = 1;
index_t inp_sH = in_w;
index_t inp_sD = in_h * in_w;
index_t inp_sC = in_d * inp_sD;
index_t inp_sN = out_c * inp_sC;
index_t grid_sCoor = 1;
index_t grid_sW = 3;
index_t grid_sH = out_w * grid_sW;
index_t grid_sD = out_h * grid_sH;
index_t grid_sN = out_d * grid_sD;
index_t out_sW = 1;
index_t out_sH = out_w;
index_t out_sD = out_h * out_w;
index_t out_sC = out_d * out_sD;
index_t out_sN = out_c * out_sC;
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
const index_t w = index % out_w;
const index_t h = (index / out_w) % out_h;
const index_t d = (index / (out_h * out_w)) % out_d;
const index_t n = index / (out_d * out_h * out_w);
const index_t grid_offset =
n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;
// get the corresponding input x, y, z co-ordinates from grid
T ix = grid[grid_offset];
T iy = grid[grid_offset + grid_sCoor];
T iz = grid[grid_offset + 2 * grid_sCoor];
ix = ComputePositions(ix, in_w, padding_mode, align_corners);
iy = ComputePositions(iy, in_h, padding_mode, align_corners);
iz = ComputePositions(iz, in_d, padding_mode, align_corners);
// printf("ix: %f, iy: %f, iz: %f \n", ix, iy, iz);
if (interpolation_mode == Mode::bilinear) {
// get corner pixel values from (x, y, z)
// for 4d, we used north-east-south-west
// for 5d, we add top-bottom
index_t ix_tnw = static_cast<index_t>(std::floor(ix));
index_t iy_tnw = static_cast<index_t>(std::floor(iy));
index_t iz_tnw = static_cast<index_t>(std::floor(iz));
index_t ix_tne = ix_tnw + 1;
index_t iy_tne = iy_tnw;
index_t iz_tne = iz_tnw;
index_t ix_tsw = ix_tnw;
index_t iy_tsw = iy_tnw + 1;
index_t iz_tsw = iz_tnw;
index_t ix_tse = ix_tnw + 1;
index_t iy_tse = iy_tnw + 1;
index_t iz_tse = iz_tnw;
index_t ix_bnw = ix_tnw;
index_t iy_bnw = iy_tnw;
index_t iz_bnw = iz_tnw + 1;
index_t ix_bne = ix_tnw + 1;
index_t iy_bne = iy_tnw;
index_t iz_bne = iz_tnw + 1;
index_t ix_bsw = ix_tnw;
index_t iy_bsw = iy_tnw + 1;
index_t iz_bsw = iz_tnw + 1;
index_t ix_bse = ix_tnw + 1;
index_t iy_bse = iy_tnw + 1;
index_t iz_bse = iz_tnw + 1;
// get surfaces to each neighbor:
T tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
T tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
T tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
T tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
T bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
T bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
T bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
T bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
auto inp_ptr_NC = input + n * inp_sN;
auto out_ptr_NCDHW =
output + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
for (index_t c = 0; c < out_c;
++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
*out_ptr_NCDHW = static_cast<T>(0);
if (InBounds3D(iz_tnw, iy_tnw, ix_tnw, in_d, in_h, in_w)) {
*out_ptr_NCDHW +=
inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] *
tnw;
}
if (InBounds3D(iz_tne, iy_tne, ix_tne, in_d, in_h, in_w)) {
*out_ptr_NCDHW +=
inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] *
tne;
}
if (InBounds3D(iz_tsw, iy_tsw, ix_tsw, in_d, in_h, in_w)) {
*out_ptr_NCDHW +=
inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] *
tsw;
}
if (InBounds3D(iz_tse, iy_tse, ix_tse, in_d, in_h, in_w)) {
*out_ptr_NCDHW +=
inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] *
tse;
}
if (InBounds3D(iz_bnw, iy_bnw, ix_bnw, in_d, in_h, in_w)) {
*out_ptr_NCDHW +=
inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] *
bnw;
}
if (InBounds3D(iz_bne, iy_bne, ix_bne, in_d, in_h, in_w)) {
*out_ptr_NCDHW +=
inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] *
bne;
}
if (InBounds3D(iz_bsw, iy_bsw, ix_bsw, in_d, in_h, in_w)) {
*out_ptr_NCDHW +=
inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] *
bsw;
}
if (InBounds3D(iz_bse, iy_bse, ix_bse, in_d, in_h, in_w)) {
*out_ptr_NCDHW +=
inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] *
bse;
}
}
} else if (interpolation_mode == Mode::nearest) {
index_t ix_nearest = static_cast<index_t>(std::round(ix));
index_t iy_nearest = static_cast<index_t>(std::round(iy));
index_t iz_nearest = static_cast<index_t>(std::round(iz));
// assign nearest neighbor pixel value to output pixel
auto inp_ptr_NC = input + n * inp_sN;
auto out_ptr_NCDHW =
output + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
for (index_t c = 0; c < out_c;
++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
if (InBounds3D(iz_nearest, iy_nearest, ix_nearest, in_d, in_h, in_w)) {
*out_ptr_NCDHW =
inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH +
ix_nearest * inp_sW];
} else {
*out_ptr_NCDHW = static_cast<T>(0);
}
}
}
}
}
std::vector<paddle::Tensor> GridSample3DCUDAForward(
const paddle::Tensor& x, const paddle::Tensor& grid,
const std::string& mode, const std::string& padding_mode,
bool align_corners) {
CHECK_INPUT_GPU(x);
CHECK_INPUT_GPU(grid);
PaddingMode enum_padding_mode;
Mode enum_mode;
if (padding_mode == "border") {
enum_padding_mode = PaddingMode::border;
} else if (padding_mode == "reflection") {
enum_padding_mode = PaddingMode::reflect;
} else {
enum_padding_mode = PaddingMode::zeros;
}
if (mode == "nearest") {
enum_mode = Mode::nearest;
} else {
enum_mode = Mode::bilinear;
}
const int n = grid.shape()[0];
const int out_d = grid.shape()[1];
const int out_h = grid.shape()[2];
const int out_w = grid.shape()[3];
const int c = x.shape()[1];
const int in_d = x.shape()[2];
const int in_h = x.shape()[3];
const int in_w = x.shape()[4];
auto output = paddle::full({n, c, out_d, out_h, out_w}, 0,
paddle::DataType::FLOAT32, paddle::GPUPlace());
const int count = static_cast<int>(n * out_d * out_h * out_w);
int max_threads_per_block = 512;
int block_num = (count - 1) / max_threads_per_block + 1;
// printf("size: %d, %d, %d, %d, %d, %d \n", n, c, out_d, out_h, count,
// block_num);
GridSample3DCudaKernel<float, int>
<<<block_num, max_threads_per_block, 0, x.stream()>>>(
count, c, out_d, out_h, out_w, in_d, in_h, in_w, x.data<float>(),
grid.data<float>(), output.data<float>(), enum_mode,
enum_padding_mode, align_corners);
cudaError_t error_check;
error_check = cudaGetLastError();
if (error_check != cudaSuccess) {
printf("%s\n", cudaGetErrorString(error_check));
}
// printf("size: %d, %d, %d, %d, %d, %d \n", n, c, out_d, out_h, count,
// block_num);
return {output};
}
template <typename T>
static __forceinline__ __device__ T UnnormalizeWithMask(T coord, int size,
bool align_corners,
T* grad_in) {
if (align_corners) {
*grad_in = static_cast<T>(size - 1) / 2;
return ((coord + 1.f) / 2) * (size - 1);
} else {
*grad_in = static_cast<T>(size) / 2;
return ((coord + 1.f) * size - 1) / 2;
}
}
template <typename T>
static __forceinline__ __device__ T ClipIndexesWithMask(T in, int clip_limit,
T* grad_in) {
if (in <= static_cast<T>(0)) {
*grad_in = static_cast<T>(0);
return static_cast<T>(0);
} else {
T max = static_cast<T>(clip_limit - 1);
if (in >= max) {
*grad_in = static_cast<T>(0);
return max;
} else {
*grad_in = static_cast<T>(1);
return in;
}
}
}
template <typename T>
static __forceinline__ __device__ T ReflectIndexesWithMask(T in, int twice_low,
int twice_high,
T* grad_in) {
if (twice_low == twice_high) {
*grad_in = static_cast<T>(0);
return static_cast<T>(0);
}
int grad_in_mult_;
T min = static_cast<T>(twice_low) / 2;
T span = static_cast<T>(twice_high - twice_low) / 2;
in = in - min;
if (in < static_cast<T>(0)) {
grad_in_mult_ = -1;
in = -in;
} else {
grad_in_mult_ = 1;
}
T extra = fmod(in, span);
int flips = static_cast<int>(floor(in / span));
if (flips % 2 == 0) {
*grad_in = static_cast<T>(grad_in_mult_);
return extra + min;
} else {
*grad_in = static_cast<T>(-grad_in_mult_);
return span - extra + min;
}
}
template <typename T>
static __forceinline__ __device__ T
ComputePositionsWithMask(T coord, int size, PaddingMode padding_mode,
bool align_corners, T* grad_in) {
T grad_clip, grad_refl;
coord = UnnormalizeWithMask<T>(coord, size, align_corners, grad_in);
if (padding_mode == PaddingMode::border) {
coord = ClipIndexesWithMask(coord, size, &grad_clip);
*grad_in = (*grad_in) * grad_clip;
} else if (padding_mode == PaddingMode::reflect) {
if (align_corners) {
coord = ReflectIndexesWithMask(coord, 0, 2 * (size - 1), &grad_refl);
} else {
coord = ReflectIndexesWithMask(coord, -1, 2 * size - 1, &grad_refl);
}
coord = ClipIndexesWithMask(coord, size, &grad_clip);
*grad_in = (*grad_in) * grad_refl * grad_clip;
}
return coord;
}
template <typename T>
static __forceinline__ __device__ void AtomicAdd3D(
T* data, int64_t d, int64_t h, int64_t w, int64_t sD, int64_t sH,
int64_t sW, int64_t D, int64_t H, int64_t W, T delta) {
if (InBounds3D(d, h, w, D, H, W)) {
atomicAdd(data + d * sD + h * sH + w * sW, delta);
}
}
template <typename T, typename index_t>
__global__ void GridSample3DCudaBackwardKernel(
const index_t nthreads, const T* grad_output, const T* input, const T* grid,
index_t out_c, index_t out_d, index_t out_h, index_t out_w, index_t in_d,
index_t in_h, index_t in_w, T* grad_input, T* grad_grid, const Mode mode,
const PaddingMode padding_mode, bool align_corners) {
index_t inp_sW = 1;
index_t inp_sH = in_w;
index_t inp_sD = in_h * in_w;
index_t inp_sC = in_d * inp_sD;
index_t inp_sN = out_c * inp_sC;
index_t grid_sCoor = 1;
index_t grid_sW = 3;
index_t grid_sH = out_w * grid_sW;
index_t grid_sD = out_h * grid_sH;
index_t grid_sN = out_d * grid_sD;
index_t gOut_sW = 1;
index_t gOut_sH = out_w;
index_t gOut_sD = out_h * out_w;
index_t gOut_sC = out_d * gOut_sD;
index_t gOut_sN = out_c * gOut_sC;
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
const index_t w = index % out_w;
const index_t h = (index / out_w) % out_h;
const index_t d = (index / (out_h * out_w)) % out_d;
const index_t n = index / (out_d * out_h * out_w);
const auto grid_offset =
n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;
// get the corresponding input x, y, z co-ordinates from grid
T ix = grid[grid_offset];
T iy = grid[grid_offset + grid_sCoor];
T iz = grid[grid_offset + 2 * grid_sCoor];
// multipliers for gradients on ix, iy, and iz
T gix_mult, giy_mult, giz_mult;
ix = ComputePositionsWithMask(ix, in_w, padding_mode, align_corners,
&gix_mult);
iy = ComputePositionsWithMask(iy, in_h, padding_mode, align_corners,
&giy_mult);
iz = ComputePositionsWithMask(iz, in_d, padding_mode, align_corners,
&giz_mult);
if (mode == Mode::bilinear) {
// get corner pixel values from (x, y, z)
// for 4d, we used north-east-south-west
// for 5d, we add top-bottom
index_t ix_tnw = static_cast<index_t>(std::floor(ix));
index_t iy_tnw = static_cast<index_t>(std::floor(iy));
index_t iz_tnw = static_cast<index_t>(std::floor(iz));
index_t ix_tne = ix_tnw + 1;
index_t iy_tne = iy_tnw;
index_t iz_tne = iz_tnw;
index_t ix_tsw = ix_tnw;
index_t iy_tsw = iy_tnw + 1;
index_t iz_tsw = iz_tnw;
index_t ix_tse = ix_tnw + 1;
index_t iy_tse = iy_tnw + 1;
index_t iz_tse = iz_tnw;
index_t ix_bnw = ix_tnw;
index_t iy_bnw = iy_tnw;
index_t iz_bnw = iz_tnw + 1;
index_t ix_bne = ix_tnw + 1;
index_t iy_bne = iy_tnw;
index_t iz_bne = iz_tnw + 1;
index_t ix_bsw = ix_tnw;
index_t iy_bsw = iy_tnw + 1;
index_t iz_bsw = iz_tnw + 1;
index_t ix_bse = ix_tnw + 1;
index_t iy_bse = iy_tnw + 1;
index_t iz_bse = iz_tnw + 1;
// get surfaces to each neighbor:
T tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
T tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
T tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
T tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
T bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
T bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
T bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
T bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
T gix = static_cast<T>(0), giy = static_cast<T>(0),
giz = static_cast<T>(0);
index_t gOut_offset =
n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
index_t inp_offset_NC = n * inp_sN;
T* gInp_ptr_NC = grad_input + n * inp_sN;
for (index_t c = 0; c < out_c; ++c, gOut_offset += gOut_sC,
gInp_ptr_NC += inp_sC, inp_offset_NC += inp_sC) {
T gOut = grad_output[gOut_offset];
AtomicAdd3D(gInp_ptr_NC, iz_tnw, iy_tnw, ix_tnw, inp_sD, inp_sH, inp_sW,
in_d, in_h, in_w, tnw * gOut);
AtomicAdd3D(gInp_ptr_NC, iz_tne, iy_tne, ix_tne, inp_sD, inp_sH, inp_sW,
in_d, in_h, in_w, tne * gOut);
AtomicAdd3D(gInp_ptr_NC, iz_tsw, iy_tsw, ix_tsw, inp_sD, inp_sH, inp_sW,
in_d, in_h, in_w, tsw * gOut);
AtomicAdd3D(gInp_ptr_NC, iz_tse, iy_tse, ix_tse, inp_sD, inp_sH, inp_sW,
in_d, in_h, in_w, tse * gOut);
AtomicAdd3D(gInp_ptr_NC, iz_bnw, iy_bnw, ix_bnw, inp_sD, inp_sH, inp_sW,
in_d, in_h, in_w, bnw * gOut);
AtomicAdd3D(gInp_ptr_NC, iz_bne, iy_bne, ix_bne, inp_sD, inp_sH, inp_sW,
in_d, in_h, in_w, bne * gOut);
AtomicAdd3D(gInp_ptr_NC, iz_bsw, iy_bsw, ix_bsw, inp_sD, inp_sH, inp_sW,
in_d, in_h, in_w, bsw * gOut);
AtomicAdd3D(gInp_ptr_NC, iz_bse, iy_bse, ix_bse, inp_sD, inp_sH, inp_sW,
in_d, in_h, in_w, bse * gOut);
// calculate grad_grid
if (InBounds3D(iz_tnw, iy_tnw, ix_tnw, in_d, in_h, in_w)) {
T tnw_val = input[inp_offset_NC + iz_tnw * inp_sD + iy_tnw * inp_sH +
ix_tnw * inp_sW];
gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut;
giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut;
giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut;
}
if (InBounds3D(iz_tne, iy_tne, ix_tne, in_d, in_h, in_w)) {
T tne_val = input[inp_offset_NC + iz_tne * inp_sD + iy_tne * inp_sH +
ix_tne * inp_sW];
gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut;
giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut;
giz -= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut;
}
if (InBounds3D(iz_tsw, iy_tsw, ix_tsw, in_d, in_h, in_w)) {
T tsw_val = input[inp_offset_NC + iz_tsw * inp_sD + iy_tsw * inp_sH +
ix_tsw * inp_sW];
gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut;
giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut;
giz -= tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut;
}
if (InBounds3D(iz_tse, iy_tse, ix_tse, in_d, in_h, in_w)) {
T tse_val = input[inp_offset_NC + iz_tse * inp_sD + iy_tse * inp_sH +
ix_tse * inp_sW];
gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut;
giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut;
giz -= tse_val * (ix - ix_bnw) * (iy - iy_bnw) * gOut;
}
if (InBounds3D(iz_bnw, iy_bnw, ix_bnw, in_d, in_h, in_w)) {
T bnw_val = input[inp_offset_NC + iz_bnw * inp_sD + iy_bnw * inp_sH +
ix_bnw * inp_sW];
gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * gOut;
giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut;
giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut;
}
if (InBounds3D(iz_bne, iy_bne, ix_bne, in_d, in_h, in_w)) {
T bne_val = input[inp_offset_NC + iz_bne * inp_sD + iy_bne * inp_sH +
ix_bne * inp_sW];
gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut;
giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut;
giz += bne_val * (ix - ix_tsw) * (iy_tsw - iy) * gOut;
}
if (InBounds3D(iz_bsw, iy_bsw, ix_bsw, in_d, in_h, in_w)) {
T bsw_val = input[inp_offset_NC + iz_bsw * inp_sD + iy_bsw * inp_sH +
ix_bsw * inp_sW];
gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut;
giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut;
giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut;
}
if (InBounds3D(iz_bse, iy_bse, ix_bse, in_d, in_h, in_w)) {
T bse_val = input[inp_offset_NC + iz_bse * inp_sD + iy_bse * inp_sH +
ix_bse * inp_sW];
gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut;
giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut;
giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut;
}
}
if (grad_grid != nullptr) {
T* gGrid_ptr_NDHW = grad_grid + index * grid_sW;
gGrid_ptr_NDHW[0] = gix_mult * gix;
gGrid_ptr_NDHW[1] = giy_mult * giy;
gGrid_ptr_NDHW[2] = giz_mult * giz;
}
} else if (mode == Mode::nearest) {
auto ix_nearest = static_cast<index_t>(std::round(ix));
auto iy_nearest = static_cast<index_t>(std::round(iy));
auto iz_nearest = static_cast<index_t>(std::round(iz));
// assign nearest neighbor pixel value to output pixel
index_t gOut_offset =
n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
T* gInp_ptr_NC = grad_input + n * inp_sN;
for (index_t c = 0; c < out_c;
++c, gOut_offset += gOut_sC, gInp_ptr_NC += inp_sC) {
AtomicAdd3D(gInp_ptr_NC, iz_nearest, iy_nearest, ix_nearest, inp_sD,
inp_sH, inp_sW, in_d, in_h, in_w, grad_output[gOut_offset]);
}
if (grad_grid != nullptr) {
T* gGrid_ptr_NDHW = grad_grid + index * grid_sW;
gGrid_ptr_NDHW[0] = static_cast<T>(0);
gGrid_ptr_NDHW[1] = static_cast<T>(0);
gGrid_ptr_NDHW[2] = static_cast<T>(0);
}
}
}
}
std::vector<paddle::Tensor> GridSample3DCUDABackward(
const paddle::Tensor& x, const paddle::Tensor& grid,
const paddle::Tensor& grad_out, const std::string& mode,
const std::string& padding_mode, bool align_corners) {
PaddingMode enum_padding_mode;
Mode enum_mode;
if (padding_mode == "border") {
enum_padding_mode = PaddingMode::border;
} else if (padding_mode == "reflection") {
enum_padding_mode = PaddingMode::reflect;
} else {
enum_padding_mode = PaddingMode::zeros;
}
if (mode == "nearest") {
enum_mode = Mode::nearest;
} else {
enum_mode = Mode::bilinear;
}
const int out_d = grid.shape()[1];
const int out_h = grid.shape()[2];
const int out_w = grid.shape()[3];
const int n = x.shape()[0];
const int c = x.shape()[1];
const int in_d = x.shape()[2];
const int in_h = x.shape()[3];
const int in_w = x.shape()[4];
auto grid_grad_output =
paddle::empty({n, out_d, out_h, out_w, 3}, paddle::DataType::FLOAT32,
paddle::GPUPlace());
auto x_grad_output =
paddle::full({n, c, in_d, in_h, in_w}, 0, paddle::DataType::FLOAT32,
paddle::GPUPlace());
const int count = static_cast<int>(n * out_d * out_h * out_w);
int max_threads_per_block = 512;
int block_num = (count - 1) / max_threads_per_block + 1;
GridSample3DCudaBackwardKernel<float, int>
<<<block_num, max_threads_per_block, 0, x.stream()>>>(
count, grad_out.data<float>(), x.data<float>(), grid.data<float>(), c,
out_d, out_h, out_w, in_d, in_h, in_w, x_grad_output.data<float>(),
grid_grad_output.data<float>(), enum_mode, enum_padding_mode,
align_corners);
return {x_grad_output};
}
} // namespace paddle_custom_ops
} // namespace fastdeploy
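
In both kernels above, the eight factors tnw through bse are trilinear interpolation weights: each corner of the enclosing unit cell is weighted by the volume of the sub-box diagonally opposite it, so the weights sum to 1 whenever the sample point lies inside the cell. A host-side check of the same formulas (illustrative only):

// Illustrative check (not part of the commit): the eight corner weights used
// in the kernels are trilinear volumes and sum to 1 inside the unit cell.
#include <cmath>
#include <cstdio>

int main() {
  float ix = 2.3f, iy = 5.7f, iz = 1.1f;  // continuous sample position
  float x0 = std::floor(ix), y0 = std::floor(iy), z0 = std::floor(iz);
  float x1 = x0 + 1, y1 = y0 + 1, z1 = z0 + 1;
  // Same formulas as tnw..bse above: each corner's weight is the volume of
  // the diagonally opposite sub-box.
  float tnw = (x1 - ix) * (y1 - iy) * (z1 - iz);
  float tne = (ix - x0) * (y1 - iy) * (z1 - iz);
  float tsw = (x1 - ix) * (iy - y0) * (z1 - iz);
  float tse = (ix - x0) * (iy - y0) * (z1 - iz);
  float bnw = (x1 - ix) * (y1 - iy) * (iz - z0);
  float bne = (ix - x0) * (y1 - iy) * (iz - z0);
  float bsw = (x1 - ix) * (iy - y0) * (iz - z0);
  float bse = (ix - x0) * (iy - y0) * (iz - z0);
  std::printf("sum = %f\n",  // expect 1.0
              tnw + tne + tsw + tse + bnw + bne + bsw + bse);
  return 0;
}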

View File

@@ -0,0 +1,33 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
#pragma once
#include <cassert>
#include <cmath>
#include <vector>
namespace fastdeploy {
namespace paddle_custom_ops {
#define HOST_DEVICE __host__ __device__
#define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__
enum class Mode { bilinear, nearest };
enum class PaddingMode { zeros, border, reflect };
} // namespace paddle_custom_ops
} // namespace fastdeploy
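
PaddingMode::reflect declared here maps out-of-range coordinates back into the valid range by mirroring. A host-side sketch of the same fold as ReflectIndexes in grid_sample_3d.cu, shown for the align_corners = true case (illustrative, not part of the commit):

// Illustrative reimplementation of the reflect fold: mirror an unnormalized
// coordinate into [0, size - 1] for align_corners = true.
#include <cmath>
#include <cstdio>

float Reflect(float in, int twice_low, int twice_high) {
  if (twice_low == twice_high) return 0.f;
  float lo = twice_low / 2.f;
  float span = (twice_high - twice_low) / 2.f;
  in = std::fabs(in - lo);
  float extra = std::fmod(in, span);
  int flips = (int)std::floor(in / span);
  return (flips % 2 == 0) ? extra + lo : span - extra + lo;
}

int main() {
  int size = 5;  // valid indices 0..4, so reflect over [0, 2*(size-1)]
  for (float c : {-1.5f, 0.f, 4.f, 5.5f, 9.f})
    std::printf("%5.1f -> %4.1f\n", c, Reflect(c, 0, 2 * (size - 1)));
  // -1.5 -> 1.5, 0 -> 0, 4 -> 4, 5.5 -> 2.5, 9 -> 1: all mirrored into [0,4]
  return 0;
}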

View File

@@ -0,0 +1,272 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
3D Rotated IoU Calculation (CPU)
Written by Shaoshuai Shi
All Rights Reserved 2020.
*/
#include "iou3d_cpu.h"
#include <math.h>
#include <stdio.h>
#include <vector>
namespace fastdeploy {
namespace paddle_custom_ops {
static inline float min(float a, float b) { return a > b ? b : a; }
static inline float max(float a, float b) { return a > b ? a : b; }
#if defined(_WIN32)
#if defined(EPS)
#undef EPS
#endif
#define EPS 1e-8
#else
static const float EPS = 1e-8;
#endif
struct Point {
float x, y;
Point() {}
Point(double _x, double _y) { x = _x, y = _y; }
void set(float _x, float _y) {
x = _x;
y = _y;
}
Point operator+(const Point &b) const {
return Point(x + b.x, y + b.y);
}
Point operator-(const Point &b) const {
return Point(x - b.x, y - b.y);
}
};
static inline float cross(const Point &a, const Point &b) {
return a.x * b.y - a.y * b.x;
}
static inline float cross(const Point &p1, const Point &p2, const Point &p0) {
return (p1.x - p0.x) * (p2.y - p0.y) - (p2.x - p0.x) * (p1.y - p0.y);
}
static inline int check_rect_cross(const Point &p1, const Point &p2, const Point &q1,
const Point &q2) {
int ret = min(p1.x, p2.x) <= max(q1.x, q2.x) &&
min(q1.x, q2.x) <= max(p1.x, p2.x) &&
min(p1.y, p2.y) <= max(q1.y, q2.y) &&
min(q1.y, q2.y) <= max(p1.y, p2.y);
return ret;
}
static inline int check_in_box2d(const float *box, const Point &p) {
// params: (7) [x, y, z, dx, dy, dz, heading]
const float MARGIN = 1e-2;
float center_x = box[0], center_y = box[1];
float angle_cos = cos(-box[6]),
angle_sin =
sin(-box[6]); // rotate the point in the opposite direction of box
float rot_x = (p.x - center_x) * angle_cos + (p.y - center_y) * (-angle_sin);
float rot_y = (p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos;
return (fabs(rot_x) < box[3] / 2 + MARGIN &&
fabs(rot_y) < box[4] / 2 + MARGIN);
}
static inline int intersection(const Point &p1, const Point &p0, const Point &q1,
const Point &q0, Point &ans) {
// fast exclusion
if (check_rect_cross(p0, p1, q0, q1) == 0) return 0;
// check cross standing
float s1 = cross(q0, p1, p0);
float s2 = cross(p1, q1, p0);
float s3 = cross(p0, q1, q0);
float s4 = cross(q1, p1, q0);
if (!(s1 * s2 > 0 && s3 * s4 > 0)) return 0;
// calculate intersection of two lines
float s5 = cross(q1, p1, p0);
if (fabs(s5 - s1) > EPS) {
ans.x = (s5 * q0.x - s1 * q1.x) / (s5 - s1);
ans.y = (s5 * q0.y - s1 * q1.y) / (s5 - s1);
} else {
float a0 = p0.y - p1.y, b0 = p1.x - p0.x, c0 = p0.x * p1.y - p1.x * p0.y;
float a1 = q0.y - q1.y, b1 = q1.x - q0.x, c1 = q0.x * q1.y - q1.x * q0.y;
float D = a0 * b1 - a1 * b0;
ans.x = (b0 * c1 - b1 * c0) / D;
ans.y = (a1 * c0 - a0 * c1) / D;
}
return 1;
}
static inline void rotate_around_center(const Point &center, const float angle_cos,
const float angle_sin, Point &p) {
float new_x =
(p.x - center.x) * angle_cos + (p.y - center.y) * (-angle_sin) + center.x;
float new_y =
(p.x - center.x) * angle_sin + (p.y - center.y) * angle_cos + center.y;
p.set(new_x, new_y);
}
static inline int point_cmp(const Point &a, const Point &b, const Point &center) {
return atan2(a.y - center.y, a.x - center.x) >
atan2(b.y - center.y, b.x - center.x);
}
static inline float box_overlap(const float *box_a, const float *box_b) {
// params: box_a (7) [x, y, z, dx, dy, dz, heading]
// params: box_b (7) [x, y, z, dx, dy, dz, heading]
// float a_x1 = box_a[0], a_y1 = box_a[1], a_x2 = box_a[2], a_y2 =
// box_a[3], a_angle = box_a[4];
// float b_x1 = box_b[0], b_y1 = box_b[1], b_x2 = box_b[2], b_y2 =
// box_b[3], b_angle = box_b[4];
float a_angle = box_a[6], b_angle = box_b[6];
float a_dx_half = box_a[3] / 2, b_dx_half = box_b[3] / 2,
a_dy_half = box_a[4] / 2, b_dy_half = box_b[4] / 2;
float a_x1 = box_a[0] - a_dx_half, a_y1 = box_a[1] - a_dy_half;
float a_x2 = box_a[0] + a_dx_half, a_y2 = box_a[1] + a_dy_half;
float b_x1 = box_b[0] - b_dx_half, b_y1 = box_b[1] - b_dy_half;
float b_x2 = box_b[0] + b_dx_half, b_y2 = box_b[1] + b_dy_half;
Point center_a(box_a[0], box_a[1]);
Point center_b(box_b[0], box_b[1]);
Point box_a_corners[5];
box_a_corners[0].set(a_x1, a_y1);
box_a_corners[1].set(a_x2, a_y1);
box_a_corners[2].set(a_x2, a_y2);
box_a_corners[3].set(a_x1, a_y2);
Point box_b_corners[5];
box_b_corners[0].set(b_x1, b_y1);
box_b_corners[1].set(b_x2, b_y1);
box_b_corners[2].set(b_x2, b_y2);
box_b_corners[3].set(b_x1, b_y2);
// get oriented corners
float a_angle_cos = cos(a_angle), a_angle_sin = sin(a_angle);
float b_angle_cos = cos(b_angle), b_angle_sin = sin(b_angle);
for (int k = 0; k < 4; k++) {
rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners[k]);
rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners[k]);
}
box_a_corners[4] = box_a_corners[0];
box_b_corners[4] = box_b_corners[0];
// get intersection of lines
Point cross_points[16];
Point poly_center;
int cnt = 0, flag = 0;
poly_center.set(0, 0);
for (int i = 0; i < 4; i++) {
for (int j = 0; j < 4; j++) {
flag = intersection(box_a_corners[i + 1], box_a_corners[i],
box_b_corners[j + 1], box_b_corners[j],
cross_points[cnt]);
if (flag) {
poly_center = poly_center + cross_points[cnt];
cnt++;
}
}
}
// check corners
for (int k = 0; k < 4; k++) {
if (check_in_box2d(box_a, box_b_corners[k])) {
poly_center = poly_center + box_b_corners[k];
cross_points[cnt] = box_b_corners[k];
cnt++;
}
if (check_in_box2d(box_b, box_a_corners[k])) {
poly_center = poly_center + box_a_corners[k];
cross_points[cnt] = box_a_corners[k];
cnt++;
}
}
poly_center.x /= cnt;
poly_center.y /= cnt;
// sort the points of polygon
Point temp;
for (int j = 0; j < cnt - 1; j++) {
for (int i = 0; i < cnt - j - 1; i++) {
if (point_cmp(cross_points[i], cross_points[i + 1], poly_center)) {
temp = cross_points[i];
cross_points[i] = cross_points[i + 1];
cross_points[i + 1] = temp;
}
}
}
// get the overlap areas
float area = 0;
for (int k = 0; k < cnt - 1; k++) {
area += cross(cross_points[k] - cross_points[0],
cross_points[k + 1] - cross_points[0]);
}
return fabs(area) / 2.0;
}
static inline float iou_bev(const float *box_a, const float *box_b) {
// params: box_a (7) [x, y, z, dx, dy, dz, heading]
// params: box_b (7) [x, y, z, dx, dy, dz, heading]
float sa = box_a[3] * box_a[4];
float sb = box_b[3] * box_b[4];
float s_overlap = box_overlap(box_a, box_b);
return s_overlap / fmaxf(sa + sb - s_overlap, EPS);
}
int boxes_iou_bev_cpu(paddle::Tensor boxes_a_tensor,
paddle::Tensor boxes_b_tensor,
paddle::Tensor ans_iou_tensor) {
// params boxes_a_tensor: (N, 7) [x, y, z, dx, dy, dz, heading]
// params boxes_b_tensor: (M, 7) [x, y, z, dx, dy, dz, heading]
// params ans_iou_tensor: (N, M)
// CHECK_CONTIGUOUS(boxes_a_tensor);
// CHECK_CONTIGUOUS(boxes_b_tensor);
int num_boxes_a = boxes_a_tensor.shape()[0];
int num_boxes_b = boxes_b_tensor.shape()[0];
const float *boxes_a = boxes_a_tensor.data<float>();
const float *boxes_b = boxes_b_tensor.data<float>();
float *ans_iou = ans_iou_tensor.data<float>();
for (int i = 0; i < num_boxes_a; i++) {
for (int j = 0; j < num_boxes_b; j++) {
ans_iou[i * num_boxes_b + j] = iou_bev(boxes_a + i * 7, boxes_b + j * 7);
}
}
return 1;
}
} // namespace paddle_custom_ops
} // namespace fastdeploy
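
Because the helpers above are file-local statics, a standalone sanity check has to duplicate the box layout [x, y, z, dx, dy, dz, heading]. For heading = 0 the rotated BEV IoU computed by iou_bev reduces to plain axis-aligned IoU, which the sketch below evaluates directly (illustrative only; the expected value for these two boxes is 1/3):

// Standalone check (illustrative, duplicating the box layout used above):
// for heading = 0 the rotated BEV IoU equals the axis-aligned IoU.
#include <algorithm>
#include <cstdio>

static float AxisAlignedIoUBEV(const float* a, const float* b) {
  // Box layout: [x, y, z, dx, dy, dz, heading]; heading assumed 0 here.
  float ax1 = a[0] - a[3] / 2, ax2 = a[0] + a[3] / 2;
  float ay1 = a[1] - a[4] / 2, ay2 = a[1] + a[4] / 2;
  float bx1 = b[0] - b[3] / 2, bx2 = b[0] + b[3] / 2;
  float by1 = b[1] - b[4] / 2, by2 = b[1] + b[4] / 2;
  float iw = std::max(0.f, std::min(ax2, bx2) - std::max(ax1, bx1));
  float ih = std::max(0.f, std::min(ay2, by2) - std::max(ay1, by1));
  float inter = iw * ih;
  float uni = a[3] * a[4] + b[3] * b[4] - inter;
  return inter / uni;
}

int main() {
  float box_a[7] = {0.f, 0.f, 0.f, 2.f, 2.f, 1.5f, 0.f};
  float box_b[7] = {1.f, 0.f, 0.f, 2.f, 2.f, 1.5f, 0.f};
  std::printf("IoU = %f\n", AxisAlignedIoUBEV(box_a, box_b));  // expect 1/3
  return 0;
}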

View File

@@ -0,0 +1,35 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#if defined(PADDLEINFERENCE_API_COMPAT_2_4_x)
#include "paddle/include/experimental/ext_all.h"
#elif defined(PADDLEINFERENCE_API_COMPAT_2_5_x)
#include "paddle/include/paddle/extension.h"
#else
#include "paddle/extension.h"
#endif
#include "fastdeploy/utils/utils.h"
namespace fastdeploy {
namespace paddle_custom_ops {
FASTDEPLOY_DECL int boxes_iou_bev_cpu(
paddle::Tensor boxes_a_tensor, paddle::Tensor boxes_b_tensor,
paddle::Tensor ans_iou_tensor);
} // namespace paddle_custom_ops
} // namespace fastdeploy

View File

@@ -0,0 +1,237 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
3D IoU Calculation and Rotated NMS(modified from 2D NMS written by others)
Written by Shaoshuai Shi
All Rights Reserved 2019-2020.
*/
#if defined(WITH_GPU)
#include <cuda.h>
#include <cuda_runtime_api.h>
#include "iou3d_nms.h"
namespace fastdeploy {
namespace paddle_custom_ops {
#define CHECK_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")
// #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
static inline int DIVUP(const int m, const int n)
{ return ((m) / (n) + ((m) % (n) > 0)); }
#define CHECK_ERROR(ans) \
{ gpuAssert((ans), __FILE__, __LINE__); }
inline void gpuAssert(cudaError_t code, const char *file, int line,
bool abort = true) {
if (code != cudaSuccess) {
fprintf(stderr, "GPUassert: %s %s %d\n", cudaGetErrorString(code), file,
line);
if (abort) exit(code);
}
}
#define D(x) \
PD_THROW('\n', x, \
"\n--------------------------------- where is the error ? " \
"---------------------------------------\n");
static const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8;
void boxesoverlapLauncher(const int num_a, const float *boxes_a,
const int num_b, const float *boxes_b,
float *ans_overlap);
void boxesioubevLauncher(const int num_a, const float *boxes_a, const int num_b,
const float *boxes_b, float *ans_iou);
void nmsLauncher(const float *boxes, unsigned long long *mask, int boxes_num,
float nms_overlap_thresh);
void nmsNormalLauncher(const float *boxes, unsigned long long *mask,
int boxes_num, float nms_overlap_thresh);
int boxes_overlap_bev_gpu(paddle::Tensor boxes_a, paddle::Tensor boxes_b,
paddle::Tensor ans_overlap) {
// params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
// params boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading]
// params ans_overlap: (N, M)
CHECK_INPUT(boxes_a);
CHECK_INPUT(boxes_b);
CHECK_INPUT(ans_overlap);
int num_a = boxes_a.shape()[0];
int num_b = boxes_b.shape()[0];
const float *boxes_a_data = boxes_a.data<float>();
const float *boxes_b_data = boxes_b.data<float>();
float *ans_overlap_data = ans_overlap.data<float>();
boxesoverlapLauncher(num_a, boxes_a_data, num_b, boxes_b_data,
ans_overlap_data);
return 1;
}
int boxes_iou_bev_gpu(paddle::Tensor boxes_a, paddle::Tensor boxes_b,
paddle::Tensor ans_iou) {
// params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
// params boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading]
// params ans_overlap: (N, M)
CHECK_INPUT(boxes_a);
CHECK_INPUT(boxes_b);
CHECK_INPUT(ans_iou);
int num_a = boxes_a.shape()[0];
int num_b = boxes_b.shape()[0];
const float *boxes_a_data = boxes_a.data<float>();
const float *boxes_b_data = boxes_b.data<float>();
float *ans_iou_data = ans_iou.data<float>();
boxesioubevLauncher(num_a, boxes_a_data, num_b, boxes_b_data, ans_iou_data);
return 1;
}
std::vector<paddle::Tensor> nms_gpu(const paddle::Tensor &boxes,
float nms_overlap_thresh) {
// params boxes: (N, 7) [x, y, z, dx, dy, dz, heading]
// params keep: (N)
CHECK_INPUT(boxes);
// CHECK_CONTIGUOUS(keep);
auto keep = paddle::empty({boxes.shape()[0]}, paddle::DataType::INT32,
paddle::CPUPlace());
auto num_to_keep_tensor =
paddle::empty({1}, paddle::DataType::INT32, paddle::CPUPlace());
int *num_to_keep_data = num_to_keep_tensor.data<int>();
int boxes_num = boxes.shape()[0];
const float *boxes_data = boxes.data<float>();
int *keep_data = keep.data<int>();
int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
unsigned long long *mask_data = NULL;
CHECK_ERROR(cudaMalloc((void **)&mask_data,
boxes_num * col_blocks * sizeof(unsigned long long)));
nmsLauncher(boxes_data, mask_data, boxes_num, nms_overlap_thresh);
// unsigned long long mask_cpu[boxes_num * col_blocks];
// unsigned long long *mask_cpu = new unsigned long long [boxes_num *
// col_blocks];
std::vector<unsigned long long> mask_cpu(boxes_num * col_blocks);
// printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks);
CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data,
boxes_num * col_blocks * sizeof(unsigned long long),
cudaMemcpyDeviceToHost));
cudaFree(mask_data);
// WARN(qiuyanjun): the code below would throw a compile error on Windows with
// MSVC. Thus, we chose to use std::vector to store the result instead.
// unsigned long long remv_cpu[col_blocks];
// memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));
std::vector<unsigned long long> remv_cpu(col_blocks, 0);
int num_to_keep = 0;
for (int i = 0; i < boxes_num; i++) {
int nblock = i / THREADS_PER_BLOCK_NMS;
int inblock = i % THREADS_PER_BLOCK_NMS;
if (!(remv_cpu[nblock] & (1ULL << inblock))) {
keep_data[num_to_keep++] = i;
unsigned long long *p = &mask_cpu[0] + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv_cpu[j] |= p[j];
}
}
}
num_to_keep_data[0] = num_to_keep;
if (cudaSuccess != cudaGetLastError()) printf("Error!\n");
return {keep, num_to_keep_tensor};
}
int nms_normal_gpu(paddle::Tensor boxes, paddle::Tensor keep,
float nms_overlap_thresh) {
// params boxes: (N, 7) [x, y, z, dx, dy, dz, heading]
// params keep: (N)
CHECK_INPUT(boxes);
// CHECK_CONTIGUOUS(keep);
int boxes_num = boxes.shape()[0];
const float *boxes_data = boxes.data<float>();
// WARN(qiuyanjun): long type for Tensor::data() API is not exported by paddle,
// it will raise some link error on windows with msvc. Please check:
// https://github.com/PaddlePaddle/Paddle/blob/release/2.5/paddle/phi/api/lib/tensor.cc
#if defined(_WIN32)
int *keep_data = keep.data<int>();
#else
long *keep_data = keep.data<long>();
#endif
int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
unsigned long long *mask_data = NULL;
CHECK_ERROR(cudaMalloc((void **)&mask_data,
boxes_num * col_blocks * sizeof(unsigned long long)));
nmsNormalLauncher(boxes_data, mask_data, boxes_num, nms_overlap_thresh);
// unsigned long long mask_cpu[boxes_num * col_blocks];
// unsigned long long *mask_cpu = new unsigned long long [boxes_num *
// col_blocks];
std::vector<unsigned long long> mask_cpu(boxes_num * col_blocks);
// printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks);
CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data,
boxes_num * col_blocks * sizeof(unsigned long long),
cudaMemcpyDeviceToHost));
cudaFree(mask_data);
// WARN(qiuyanjun): the code below would throw a compile error on Windows with
// MSVC. Thus, we chose to use std::vector to store the result instead.
// unsigned long long remv_cpu[col_blocks];
// memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));
std::vector<unsigned long long> remv_cpu(col_blocks, 0);
int num_to_keep = 0;
for (int i = 0; i < boxes_num; i++) {
int nblock = i / THREADS_PER_BLOCK_NMS;
int inblock = i % THREADS_PER_BLOCK_NMS;
if (!(remv_cpu[nblock] & (1ULL << inblock))) {
keep_data[num_to_keep++] = i;
unsigned long long *p = &mask_cpu[0] + i * col_blocks;
for (int j = nblock; j < col_blocks; j++) {
remv_cpu[j] |= p[j];
}
}
}
if (cudaSuccess != cudaGetLastError()) printf("Error!\n");
return num_to_keep;
}
} // namespace paddle_custom_ops
} // namespace fastdeploy
#endif
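
The decode loops above rely on a bitmask layout in which each box owns col_blocks 64-bit words: bit k of word j in box i's row records that box i suppresses box j * 64 + k (THREADS_PER_BLOCK_NMS is 64). A CPU-only miniature of the same decode with a hand-built mask (the layout is read off the code above; the example is illustrative):

// CPU-only illustration of the bitmask decode used by nms_gpu /
// nms_normal_gpu: mask[i * col_blocks + j] bit k == 1 means box i
// suppresses box j * 64 + k.
#include <cstdio>
#include <vector>

int main() {
  const int kThreads = 64;  // THREADS_PER_BLOCK_NMS
  int boxes_num = 3;
  int col_blocks = (boxes_num + kThreads - 1) / kThreads;  // DIVUP
  std::vector<unsigned long long> mask(boxes_num * col_blocks, 0ULL);
  mask[0 * col_blocks + 0] |= 1ULL << 1;  // box 0 suppresses box 1
  std::vector<unsigned long long> remv(col_blocks, 0ULL);
  std::vector<int> keep;
  for (int i = 0; i < boxes_num; ++i) {
    int nblock = i / kThreads, inblock = i % kThreads;
    if (!(remv[nblock] & (1ULL << inblock))) {
      keep.push_back(i);  // box i survives
      for (int j = nblock; j < col_blocks; ++j)
        remv[j] |= mask[i * col_blocks + j];  // fold in its suppression row
    }
  }
  for (int k : keep) std::printf("keep %d\n", k);  // expect: 0, 2
  return 0;
}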

View File

@@ -0,0 +1,46 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#if defined(PADDLEINFERENCE_API_COMPAT_2_4_x)
#include "paddle/include/experimental/ext_all.h"
#elif defined(PADDLEINFERENCE_API_COMPAT_2_5_x)
#include "paddle/include/paddle/extension.h"
#else
#include "paddle/extension.h"
#endif
#include "fastdeploy/utils/utils.h"
#if defined(WITH_GPU)
namespace fastdeploy {
namespace paddle_custom_ops {
FASTDEPLOY_DECL int boxes_overlap_bev_gpu(
paddle::Tensor boxes_a, paddle::Tensor boxes_b,
paddle::Tensor ans_overlap);
FASTDEPLOY_DECL int boxes_iou_bev_gpu(paddle::Tensor boxes_a,
paddle::Tensor boxes_b,
paddle::Tensor ans_iou);
FASTDEPLOY_DECL std::vector<paddle::Tensor> nms_gpu(
const paddle::Tensor& boxes, float nms_overlap_thresh);
FASTDEPLOY_DECL int nms_normal_gpu(
paddle::Tensor boxes, paddle::Tensor keep, float nms_overlap_thresh);
} // namespace fastdeploy
} // namespace paddle_custom_ops
#endif

View File

@@ -0,0 +1,55 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if defined(PADDLEINFERENCE_API_COMPAT_2_4_x)
#include "paddle/include/experimental/ext_all.h"
#elif defined(PADDLEINFERENCE_API_COMPAT_2_5_x)
#include "paddle/include/paddle/extension.h"
#else
#include "paddle/extension.h"
#endif
#include <vector>
#include "iou3d_cpu.h"
#include "iou3d_nms.h"
namespace fastdeploy {
namespace paddle_custom_ops {
std::vector<std::vector<int64_t>> NMSInferShape(
std::vector<int64_t> boxes_shape) {
int64_t keep_num = 1;
return {{boxes_shape[0]}, {keep_num}};
}
std::vector<paddle::DataType> NMSInferDtype(paddle::DataType boxes_dtype) {
return {paddle::DataType::INT64, paddle::DataType::INT64};
}
} // namespace fastdeploy
} // namespace paddle_custom_ops
#if defined(WITH_GPU)
PD_BUILD_OP(nms_gpu)
.Inputs({"boxes"})
.Outputs({"keep", "num_to_keep"})
.Attrs({"nms_overlap_thresh: float"})
.SetKernelFn(PD_KERNEL(fastdeploy::paddle_custom_ops::nms_gpu))
.SetInferDtypeFn(PD_INFER_DTYPE(fastdeploy::paddle_custom_ops::NMSInferDtype))
.SetInferShapeFn(PD_INFER_SHAPE(fastdeploy::paddle_custom_ops::NMSInferShape));
#endif
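PD_BUILD_OP registers nms_gpu as a Paddle custom operator with one input tensor, two int64 outputs, and a float attribute. In FastDeploy the op is compiled into the deploy library itself, but as a rough sketch of what the registration exposes, the same sources could be JIT-compiled and called from Python along these lines (the module name and source list are assumptions for illustration, not from this commit):

import paddle
from paddle.utils.cpp_extension import load

# Hypothetical JIT build of the custom op sources shown above.
custom_ops = load(name="paddle3d_custom_ops",
                  sources=["iou3d_nms_api.cc", "iou3d_nms.cc", "iou3d_nms_kernel.cu"])

boxes = paddle.rand([100, 7])                       # (N, 7) [x, y, z, dx, dy, dz, heading]
keep, num_to_keep = custom_ops.nms_gpu(boxes, 0.5)  # attr: nms_overlap_thresh
print(keep.shape, int(num_to_keep))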

View File

@@ -1,23 +1,8 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
3D IoU Calculation and Rotated NMS(modified from 2D NMS written by others)
Written by Shaoshuai Shi
All Rights Reserved 2019-2020.
*/
#include <stdio.h>
namespace fastdeploy {
@@ -78,20 +63,36 @@ __device__ int check_rect_cross(const Point &p1, const Point &p2,
__device__ inline int check_in_box2d(const float *box, const Point &p) {
// params: (7) [x, y, z, dx, dy, dz, heading]
const float MARGIN = 1e-2;
// Align with the setting of mmdet3d
// const float MARGIN = 1e-5;
float center_x = box[0], center_y = box[1];
-// rotate the point in the opposite direction of box
-float angle_cos = cos(-box[6]), angle_sin = sin(-box[6]);
+float angle_cos = cos(-box[6]),
+angle_sin =
+sin(-box[6]); // rotate the point in the opposite direction of box
float rot_x = (p.x - center_x) * angle_cos + (p.y - center_y) * (-angle_sin);
float rot_y = (p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos;
return (fabs(rot_x) < box[3] / 2 + MARGIN &&
fabs(rot_y) < box[4] / 2 + MARGIN);
// Align with the implementation of mmdet3d
// float rot_x =
// (p.x - center_x) * angle_cos + (p.y - center_y) * angle_sin + center_x;
// float rot_y =
// -(p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos +
// center_y;
// float x1 = center_x - box[3] / 2;
// float x2 = center_x + box[3] / 2;
// float y1 = center_y - box[4] / 2;
// float y2 = center_y + box[4] / 2;
// return (rot_x > x1 - MARGIN && rot_x < x2 + MARGIN && rot_y > y1 - MARGIN
// &&
// rot_y < y2 + MARGIN);
}
__device__ inline int intersection(const Point &p1, const Point &p0,
const Point &q1, const Point &q0,
-Point *ans) {
+Point &ans) {
// fast exclusion
if (check_rect_cross(p0, p1, q0, q1) == 0) return 0;
@@ -106,16 +107,16 @@ __device__ inline int intersection(const Point &p1, const Point &p0,
// calculate intersection of two lines
float s5 = cross(q1, p1, p0);
if (fabs(s5 - s1) > EPS) {
-ans->x = (s5 * q0.x - s1 * q1.x) / (s5 - s1);
-ans->y = (s5 * q0.y - s1 * q1.y) / (s5 - s1);
+ans.x = (s5 * q0.x - s1 * q1.x) / (s5 - s1);
+ans.y = (s5 * q0.y - s1 * q1.y) / (s5 - s1);
} else {
float a0 = p0.y - p1.y, b0 = p1.x - p0.x, c0 = p0.x * p1.y - p1.x * p0.y;
float a1 = q0.y - q1.y, b1 = q1.x - q0.x, c1 = q0.x * q1.y - q1.x * q0.y;
float D = a0 * b1 - a1 * b0;
-ans->x = (b0 * c1 - b1 * c0) / D;
-ans->y = (a1 * c0 - a0 * c1) / D;
+ans.x = (b0 * c1 - b1 * c0) / D;
+ans.y = (a1 * c0 - a0 * c1) / D;
}
return 1;
@@ -123,12 +124,18 @@ __device__ inline int intersection(const Point &p1, const Point &p0,
__device__ inline void rotate_around_center(const Point &center,
const float angle_cos,
-const float angle_sin, Point *p) {
-float new_x = (p->x - center.x) * angle_cos +
-(p->y - center.y) * (-angle_sin) + center.x;
-float new_y =
-(p->x - center.x) * angle_sin + (p->y - center.y) * angle_cos + center.y;
-p->set(new_x, new_y);
+const float angle_sin, Point &p) {
+// float new_x = (p.x - center.x) * angle_cos + (p.y - center.y) *
+// (-angle_sin) + center.x;
+// float new_y = (p.x - center.x) * angle_sin + (p.y - center.y) * angle_cos +
+// center.y;
+// p.set(new_x, new_y);
+// Align with the implementation of mmdet3d
+float new_x =
+(p.x - center.x) * angle_cos + (p.y - center.y) * angle_sin + center.x;
+float new_y =
+-(p.x - center.x) * angle_sin + (p.y - center.y) * angle_cos + center.y;
+p.set(new_x, new_y);
}
__device__ inline int point_cmp(const Point &a, const Point &b,
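The rewritten rotate_around_center applies the mmdet3d convention x' = dx*cos(a) + dy*sin(a), y' = -dx*sin(a) + dy*cos(a), i.e. a clockwise rotation of the corner about the box center for a positive heading. A quick numpy check of that convention (illustrative only, the helper name is ours):

import numpy as np

def rotate_around_center(center, angle, p):
    c, s = np.cos(angle), np.sin(angle)
    dx, dy = p[0] - center[0], p[1] - center[1]
    return np.array([dx * c + dy * s + center[0],
                     -dx * s + dy * c + center[1]])

# Rotating (1, 0) about the origin by +90 degrees yields (0, -1),
# i.e. a clockwise rotation for positive heading angles.
print(rotate_around_center((0.0, 0.0), np.pi / 2, (1.0, 0.0)).round(6))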
@@ -152,6 +159,14 @@ __device__ inline float box_overlap(const float *box_a, const float *box_b) {
Point center_a(box_a[0], box_a[1]);
Point center_b(box_b[0], box_b[1]);
#ifdef DEBUG
printf(
"a: (%.3f, %.3f, %.3f, %.3f, %.3f), b: (%.3f, %.3f, %.3f, %.3f, %.3f)\n",
a_x1, a_y1, a_x2, a_y2, a_angle, b_x1, b_y1, b_x2, b_y2, b_angle);
printf("center a: (%.3f, %.3f), b: (%.3f, %.3f)\n", center_a.x, center_a.y,
center_b.x, center_b.y);
#endif
Point box_a_corners[5];
box_a_corners[0].set(a_x1, a_y1);
box_a_corners[1].set(a_x2, a_y1);
@@ -169,8 +184,17 @@ __device__ inline float box_overlap(const float *box_a, const float *box_b) {
float b_angle_cos = cos(b_angle), b_angle_sin = sin(b_angle);
for (int k = 0; k < 4; k++) {
-rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners + k);
-rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners + k);
+#ifdef DEBUG
+printf("before corner %d: a(%.3f, %.3f), b(%.3f, %.3f) \n", k,
+box_a_corners[k].x, box_a_corners[k].y, box_b_corners[k].x,
+box_b_corners[k].y);
+#endif
+rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners[k]);
+rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners[k]);
+#ifdef DEBUG
+printf("corner %d: a(%.3f, %.3f), b(%.3f, %.3f) \n", k, box_a_corners[k].x,
+box_a_corners[k].y, box_b_corners[k].x, box_b_corners[k].y);
+#endif
}
box_a_corners[4] = box_a_corners[0];
@@ -186,10 +210,19 @@ __device__ inline float box_overlap(const float *box_a, const float *box_b) {
for (int j = 0; j < 4; j++) {
flag = intersection(box_a_corners[i + 1], box_a_corners[i],
box_b_corners[j + 1], box_b_corners[j],
-cross_points + cnt);
+cross_points[cnt]);
if (flag) {
poly_center = poly_center + cross_points[cnt];
cnt++;
#ifdef DEBUG
printf(
"Cross points (%.3f, %.3f): a(%.3f, %.3f)->(%.3f, %.3f), b(%.3f, "
"%.3f)->(%.3f, %.3f) \n",
cross_points[cnt - 1].x, cross_points[cnt - 1].y,
box_a_corners[i].x, box_a_corners[i].y, box_a_corners[i + 1].x,
box_a_corners[i + 1].y, box_b_corners[i].x, box_b_corners[i].y,
box_b_corners[i + 1].x, box_b_corners[i + 1].y);
#endif
}
}
}
@@ -200,11 +233,19 @@ __device__ inline float box_overlap(const float *box_a, const float *box_b) {
poly_center = poly_center + box_b_corners[k];
cross_points[cnt] = box_b_corners[k];
cnt++;
#ifdef DEBUG
printf("b corners in a: corner_b(%.3f, %.3f)", cross_points[cnt - 1].x,
cross_points[cnt - 1].y);
#endif
}
if (check_in_box2d(box_b, box_a_corners[k])) {
poly_center = poly_center + box_a_corners[k];
cross_points[cnt] = box_a_corners[k];
cnt++;
#ifdef DEBUG
printf("a corners in b: corner_a(%.3f, %.3f)", cross_points[cnt - 1].x,
cross_points[cnt - 1].y);
#endif
}
}
@@ -223,6 +264,14 @@ __device__ inline float box_overlap(const float *box_a, const float *box_b) {
}
}
#ifdef DEBUG
printf("cnt=%d\n", cnt);
for (int i = 0; i < cnt; i++) {
printf("All cross point %d: (%.3f, %.3f)\n", i, cross_points[i].x,
cross_points[i].y);
}
#endif
// get the overlap areas
float area = 0;
for (int k = 0; k < cnt - 1; k++) {
@@ -242,10 +291,220 @@ __device__ inline float iou_bev(const float *box_a, const float *box_b) {
return s_overlap / fmaxf(sa + sb - s_overlap, EPS);
}
-__global__ void nms_kernel(const int num_bboxes, const int num_bboxes_for_nms,
-const float nms_overlap_thresh,
-const int decode_bboxes_dims, const float *bboxes,
-const int *index, const int64_t *sorted_index,
-int64_t *mask) {
+__global__ void boxes_overlap_kernel(const int num_a, const float *boxes_a,
const int num_b, const float *boxes_b,
float *ans_overlap) {
// params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
// params boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading]
const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;
const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;
if (a_idx >= num_a || b_idx >= num_b) {
return;
}
const float *cur_box_a = boxes_a + a_idx * 7;
const float *cur_box_b = boxes_b + b_idx * 7;
float s_overlap = box_overlap(cur_box_a, cur_box_b);
ans_overlap[a_idx * num_b + b_idx] = s_overlap;
}
__global__ void boxes_iou_bev_kernel(const int num_a, const float *boxes_a,
const int num_b, const float *boxes_b,
float *ans_iou) {
// params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
// params boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading]
const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;
const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;
if (a_idx >= num_a || b_idx >= num_b) {
return;
}
const float *cur_box_a = boxes_a + a_idx * 7;
const float *cur_box_b = boxes_b + b_idx * 7;
float cur_iou_bev = iou_bev(cur_box_a, cur_box_b);
ans_iou[a_idx * num_b + b_idx] = cur_iou_bev;
}
__global__ void nms_kernel(const int boxes_num, const float nms_overlap_thresh,
const float *boxes, unsigned long long *mask) {
// params: boxes (N, 7) [x, y, z, dx, dy, dz, heading]
// params: mask (N, N/THREADS_PER_BLOCK_NMS)
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
// if (row_start > col_start) return;
const int row_size = fminf(boxes_num - row_start * THREADS_PER_BLOCK_NMS,
THREADS_PER_BLOCK_NMS);
const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS,
THREADS_PER_BLOCK_NMS);
__shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 7];
if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 7 + 0] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 0];
block_boxes[threadIdx.x * 7 + 1] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 1];
block_boxes[threadIdx.x * 7 + 2] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 2];
block_boxes[threadIdx.x * 7 + 3] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 3];
block_boxes[threadIdx.x * 7 + 4] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 4];
block_boxes[threadIdx.x * 7 + 5] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 5];
block_boxes[threadIdx.x * 7 + 6] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 6];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
const float *cur_box = boxes + cur_box_idx * 7;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
if (iou_bev(cur_box, block_boxes + i * 7) > nms_overlap_thresh) {
t |= 1ULL << i;
}
}
int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
mask[cur_box_idx * col_blocks + col_start] = t;
}
}
__device__ inline float iou_normal(float const *const a, float const *const b) {
// params: a: [x, y, z, dx, dy, dz, heading]
// params: b: [x, y, z, dx, dy, dz, heading]
float left = fmaxf(a[0] - a[3] / 2, b[0] - b[3] / 2),
right = fminf(a[0] + a[3] / 2, b[0] + b[3] / 2);
float top = fmaxf(a[1] - a[4] / 2, b[1] - b[4] / 2),
bottom = fminf(a[1] + a[4] / 2, b[1] + b[4] / 2);
float width = fmaxf(right - left, 0.f), height = fmaxf(bottom - top, 0.f);
float interS = width * height;
float Sa = a[3] * a[4];
float Sb = b[3] * b[4];
return interS / fmaxf(Sa + Sb - interS, EPS);
}
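iou_normal above ignores the heading and treats both boxes as axis-aligned in BEV: it intersects the x and y extents and divides by the union of the two footprints. A numpy equivalent with a small worked example (helper name ours):

import numpy as np

def iou_normal(a, b, eps=1e-8):
    # a, b use the same (x, y, z, dx, dy, dz, heading) layout as the kernels
    left = max(a[0] - a[3] / 2, b[0] - b[3] / 2)
    right = min(a[0] + a[3] / 2, b[0] + b[3] / 2)
    top = max(a[1] - a[4] / 2, b[1] - b[4] / 2)
    bottom = min(a[1] + a[4] / 2, b[1] + b[4] / 2)
    inter = max(right - left, 0.0) * max(bottom - top, 0.0)
    return inter / max(a[3] * a[4] + b[3] * b[4] - inter, eps)

a = np.array([0, 0, 0, 2, 2, 2, 0], dtype=np.float32)
b = np.array([1, 0, 0, 2, 2, 2, 0], dtype=np.float32)
print(iou_normal(a, b))  # 2 / (4 + 4 - 2) = 1/3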
__global__ void nms_normal_kernel(const int boxes_num,
const float nms_overlap_thresh,
const float *boxes,
unsigned long long *mask) {
// params: boxes (N, 7) [x, y, z, dx, dy, dz, heading]
// params: mask (N, N/THREADS_PER_BLOCK_NMS)
const int row_start = blockIdx.y;
const int col_start = blockIdx.x;
// if (row_start > col_start) return;
const int row_size = fminf(boxes_num - row_start * THREADS_PER_BLOCK_NMS,
THREADS_PER_BLOCK_NMS);
const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS,
THREADS_PER_BLOCK_NMS);
__shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 7];
if (threadIdx.x < col_size) {
block_boxes[threadIdx.x * 7 + 0] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 0];
block_boxes[threadIdx.x * 7 + 1] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 1];
block_boxes[threadIdx.x * 7 + 2] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 2];
block_boxes[threadIdx.x * 7 + 3] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 3];
block_boxes[threadIdx.x * 7 + 4] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 4];
block_boxes[threadIdx.x * 7 + 5] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 5];
block_boxes[threadIdx.x * 7 + 6] =
boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 6];
}
__syncthreads();
if (threadIdx.x < row_size) {
const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
const float *cur_box = boxes + cur_box_idx * 7;
int i = 0;
unsigned long long t = 0;
int start = 0;
if (row_start == col_start) {
start = threadIdx.x + 1;
}
for (i = start; i < col_size; i++) {
if (iou_normal(cur_box, block_boxes + i * 7) > nms_overlap_thresh) {
t |= 1ULL << i;
}
}
int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
mask[cur_box_idx * col_blocks + col_start] = t;
}
}
void boxesoverlapLauncher(const int num_a, const float *boxes_a,
const int num_b, const float *boxes_b,
float *ans_overlap) {
dim3 blocks(
DIVUP(num_b, THREADS_PER_BLOCK),
DIVUP(num_a, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK, THREADS_PER_BLOCK);
boxes_overlap_kernel<<<blocks, threads>>>(num_a, boxes_a, num_b, boxes_b,
ans_overlap);
#ifdef DEBUG
cudaDeviceSynchronize(); // for using printf in kernel function
#endif
}
void boxesioubevLauncher(const int num_a, const float *boxes_a, const int num_b,
const float *boxes_b, float *ans_iou) {
dim3 blocks(
DIVUP(num_b, THREADS_PER_BLOCK),
DIVUP(num_a, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row)
dim3 threads(THREADS_PER_BLOCK, THREADS_PER_BLOCK);
boxes_iou_bev_kernel<<<blocks, threads>>>(num_a, boxes_a, num_b, boxes_b,
ans_iou);
#ifdef DEBUG
cudaDeviceSynchronize(); // for using printf in kernel function
#endif
}
void nmsLauncher(const float *boxes, unsigned long long *mask, int boxes_num,
float nms_overlap_thresh) {
dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),
DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));
dim3 threads(THREADS_PER_BLOCK_NMS);
nms_kernel<<<blocks, threads>>>(boxes_num, nms_overlap_thresh, boxes, mask);
}
void nmsNormalLauncher(const float *boxes, unsigned long long *mask,
int boxes_num, float nms_overlap_thresh) {
dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),
DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));
dim3 threads(THREADS_PER_BLOCK_NMS);
nms_normal_kernel<<<blocks, threads>>>(boxes_num, nms_overlap_thresh, boxes,
mask);
}
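Both NMS launchers use a DIVUP(N, 64) x DIVUP(N, 64) grid of 64-thread blocks, so the mask buffer allocated by the callers holds N * col_blocks 64-bit words. A quick size check (numbers illustrative):

THREADS_PER_BLOCK_NMS = 64

def divup(m, n):
    return (m + n - 1) // n

boxes_num = 4096
col_blocks = divup(boxes_num, THREADS_PER_BLOCK_NMS)  # 64 grid columns/rows
mask_words = boxes_num * col_blocks                   # one 64-bit word per (box, block)
print(col_blocks, mask_words * 8 // 1024, "KiB")      # 64 2048 KiB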
__global__ void nms_kernel_centerpoint(const int num_bboxes,
const int num_bboxes_for_nms,
const float nms_overlap_thresh,
const int decode_bboxes_dims,
const float *bboxes, const int *index,
const int64_t *sorted_index,
int64_t *mask) {
// params: boxes (N, 7) [x, y, z, dx, dy, dz, heading]
// params: mask (N, N/THREADS_PER_BLOCK_NMS)
@@ -304,7 +563,7 @@ __global__ void nms_kernel(const int num_bboxes, const int num_bboxes_for_nms,
t |= 1ULL << i;
}
}
-const int col_blocks = DIVUP(num_bboxes_for_nms, THREADS_PER_BLOCK_NMS);
+int col_blocks = DIVUP(num_bboxes_for_nms, THREADS_PER_BLOCK_NMS);
mask[cur_box_idx * col_blocks + col_start] = t;
}
}
@@ -317,7 +576,7 @@ void NmsLauncher(const cudaStream_t &stream, const float *bboxes,
dim3 blocks(DIVUP(num_bboxes_for_nms, THREADS_PER_BLOCK_NMS),
DIVUP(num_bboxes_for_nms, THREADS_PER_BLOCK_NMS));
dim3 threads(THREADS_PER_BLOCK_NMS);
-nms_kernel<<<blocks, threads, 0, stream>>>(
+nms_kernel_centerpoint<<<blocks, threads, 0, stream>>>(
num_bboxes, num_bboxes_for_nms, nms_overlap_thresh, decode_bboxes_dims,
bboxes, index, sorted_index, mask);
}

View File

@@ -37,6 +37,7 @@
#include "fastdeploy/vision/perception/paddle3d/smoke/smoke.h" #include "fastdeploy/vision/perception/paddle3d/smoke/smoke.h"
#include "fastdeploy/vision/perception/paddle3d/petr/petr.h" #include "fastdeploy/vision/perception/paddle3d/petr/petr.h"
#include "fastdeploy/vision/perception/paddle3d/centerpoint/centerpoint.h" #include "fastdeploy/vision/perception/paddle3d/centerpoint/centerpoint.h"
#include "fastdeploy/vision/perception/paddle3d/caddn/caddn.h"
#include "fastdeploy/vision/detection/ppdet/model.h" #include "fastdeploy/vision/detection/ppdet/model.h"
#include "fastdeploy/vision/facealign/contrib/face_landmark_1000.h" #include "fastdeploy/vision/facealign/contrib/face_landmark_1000.h"
#include "fastdeploy/vision/facealign/contrib/pfld.h" #include "fastdeploy/vision/facealign/contrib/pfld.h"

View File

@@ -0,0 +1,86 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/perception/paddle3d/caddn/caddn.h"
namespace fastdeploy {
namespace vision {
namespace perception {
Caddn::Caddn(const std::string& model_file, const std::string& params_file,
const std::string& config_file, const RuntimeOption& custom_option,
const ModelFormat& model_format)
: preprocessor_(config_file) {
valid_gpu_backends = {Backend::PDINFER};
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
initialized = Initialize();
}
bool Caddn::Initialize() {
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
return true;
}
bool Caddn::Predict(const cv::Mat& im, std::vector<float>& input_cam_data,
std::vector<float>& input_lidar_data,
PerceptionResult* result) {
std::vector<PerceptionResult> results;
if (!BatchPredict({im}, input_cam_data, input_lidar_data, &results)) {
return false;
}
if (results.size()) {
*result = std::move(results[0]);
}
return true;
}
bool Caddn::BatchPredict(const std::vector<cv::Mat>& images,
std::vector<float>& input_cam_data,
std::vector<float>& input_lidar_data,
std::vector<PerceptionResult>* results) {
std::vector<FDMat> fd_images = WrapMat(images);
if (!preprocessor_.Run(&fd_images, input_cam_data, input_lidar_data,
&reused_input_tensors_)) {
FDERROR << "Failed to preprocess the input image." << std::endl;
return false;
}
reused_input_tensors_[0].name = "images";
reused_input_tensors_[1].name = "trans_cam_to_img";
reused_input_tensors_[2].name = "trans_lidar_to_cam";
if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
FDERROR << "Failed to inference by runtime." << std::endl;
return false;
}
if (!postprocessor_.Run(reused_output_tensors_, results)) {
FDERROR << "Failed to postprocess the inference results by runtime."
<< std::endl;
return false;
}
return true;
}
} // namespace perception
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,83 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. //NOLINT
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/perception/paddle3d/caddn/preprocessor.h"
#include "fastdeploy/vision/perception/paddle3d/caddn/postprocessor.h"
namespace fastdeploy {
namespace vision {
namespace perception {
/*! @brief Caddn model object, used to load a Caddn model exported by Paddle3D.
*/
class FASTDEPLOY_DECL Caddn : public FastDeployModel {
public:
/** \brief Set path of model file and the configuration of runtime.
*
* \param[in] model_file Path of model file, e.g Caddn/model.pdmodel
* \param[in] params_file Path of parameter file, e.g Caddn/model.pdiparams, if the model format is ONNX, this parameter will be ignored
* \param[in] config_file Path of configuration file for deployment, e.g Caddn/infer_cfg.yml
* \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in "valid_cpu_backends"
* \param[in] model_format Model format of the loaded model, default is Paddle format
*/
Caddn(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE);
std::string ModelName() const { return "Paddle3D/Caddn"; }
/** \brief Predict the perception result for an input image
*
* \param[in] im The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
* \param[in] input_cam_data The flattened 3x4 camera-to-image transform matrix
* \param[in] input_lidar_data The flattened 4x4 lidar-to-camera transform matrix
* \param[in] result The output perception result will be written to this structure
* \return true if the prediction succeeds, otherwise false
*/
virtual bool Predict(const cv::Mat& im,
std::vector<float>& input_cam_data,
std::vector<float>& input_lidar_data,
PerceptionResult* results);
/** \brief Predict the perception results for a batch of input images
*
* \param[in] images The input image list, each element comes from cv::imread()
* \param[in] input_cam_data The flattened 3x4 camera-to-image transform matrix, shared by the batch
* \param[in] input_lidar_data The flattened 4x4 lidar-to-camera transform matrix, shared by the batch
* \param[in] results The output perception result list
* \return true if the prediction succeeds, otherwise false
*/
virtual bool BatchPredict(const std::vector<cv::Mat>& images,
std::vector<float>& input_cam_data,
std::vector<float>& input_lidar_data,
std::vector<PerceptionResult>* results);
/// Get preprocessor reference of Caddn
virtual CaddnPreprocessor& GetPreprocessor() {
return preprocessor_;
}
/// Get postprocessor reference of Caddn
virtual CaddnPostprocessor& GetPostprocessor() {
return postprocessor_;
}
protected:
bool Initialize();
CaddnPreprocessor preprocessor_;
CaddnPostprocessor postprocessor_;
bool initialized_ = false;
};
} // namespace perception
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,96 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindCaddn(pybind11::module& m) {
pybind11::class_<vision::perception::CaddnPreprocessor,
vision::ProcessorManager>(m, "CaddnPreprocessor")
.def(pybind11::init<std::string>())
.def("run",
[](vision::perception::CaddnPreprocessor& self,
std::vector<pybind11::array>& im_list,
std::vector<float>& cam_data, std::vector<float>& lidar_data) {
std::vector<vision::FDMat> images;
for (size_t i = 0; i < im_list.size(); ++i) {
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
}
std::vector<FDTensor> outputs;
if (!self.Run(&images, cam_data, lidar_data, &outputs)) {
throw std::runtime_error(
"Failed to preprocess the input data in CaddnPreprocessor.");
}
for (size_t i = 0; i < outputs.size(); ++i) {
outputs[i].StopSharing();
}
return outputs;
});
pybind11::class_<vision::perception::CaddnPostprocessor>(m,
"CaddnPostprocessor")
.def(pybind11::init<>())
.def("run",
[](vision::perception::CaddnPostprocessor& self,
std::vector<FDTensor>& inputs) {
std::vector<vision::PerceptionResult> results;
if (!self.Run(inputs, &results)) {
throw std::runtime_error(
"Failed to postprocess the runtime result in "
"CaddnPostprocessor.");
}
return results;
})
.def("run", [](vision::perception::CaddnPostprocessor& self,
std::vector<pybind11::array>& input_array) {
std::vector<vision::PerceptionResult> results;
std::vector<FDTensor> inputs;
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
if (!self.Run(inputs, &results)) {
throw std::runtime_error(
"Failed to postprocess the runtime result in "
"CaddnPostprocessor.");
}
return results;
});
pybind11::class_<vision::perception::Caddn, FastDeployModel>(m, "Caddn")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>())
.def("predict",
[](vision::perception::Caddn& self, pybind11::array& data,
std::vector<float>& cam_data, std::vector<float>& lidar_data) {
auto mat = PyArrayToCvMat(data);
vision::PerceptionResult res;
self.Predict(mat, cam_data, lidar_data, &res);
return res;
})
.def("batch_predict",
[](vision::perception::Caddn& self,
std::vector<pybind11::array>& data, std::vector<float>& cam_data,
std::vector<float>& lidar_data) {
std::vector<cv::Mat> images;
for (size_t i = 0; i < data.size(); ++i) {
images.push_back(PyArrayToCvMat(data[i]));
}
std::vector<vision::PerceptionResult> results;
self.BatchPredict(images, cam_data, lidar_data, &results);
return results;
})
.def_property_readonly("preprocessor",
&vision::perception::Caddn::GetPreprocessor)
.def_property_readonly("postprocessor",
&vision::perception::Caddn::GetPostprocessor);
}
} // namespace fastdeploy

View File

@@ -0,0 +1,70 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/perception/paddle3d/caddn/postprocessor.h"
#include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy {
namespace vision {
namespace perception {
CaddnPostprocessor::CaddnPostprocessor() {}
bool CaddnPostprocessor::Run(const std::vector<FDTensor>& tensors,
std::vector<PerceptionResult>* results) {
results->resize(1);
(*results)[0].Clear();
(*results)[0].Reserve(tensors[0].shape[0]);
if (tensors[0].dtype != FDDataType::FP32) {
FDERROR << "Only support post process with float32 data." << std::endl;
return false;
}
const float* data_0 = reinterpret_cast<const float*>(tensors[0].Data());
auto result = &(*results)[0];
for (int i = 0; i < tensors[0].shape[0] * tensors[0].shape[1]; i += 7) {
// item 1 ~ 3 : box3d bottom center x, y, z
// item 4 ~ 6 : box3d w, h, l
// item 7 : box3d yaw angle
std::vector<float> vec(data_0 + i, data_0 + i + 7);
result->boxes.emplace_back(
std::array<float, 7>{0, 0, 0, 0, vec[3], vec[4], vec[5]});
result->center.emplace_back(std::array<float, 3>{vec[0], vec[1], vec[2]});
result->yaw_angle.push_back(vec[6]);
}
const float* data_1 = reinterpret_cast<const float*>(tensors[2].Data());
for (int i = 0; i < tensors[2].shape[0]; i += 1) {
std::vector<float> vec(data_1 + i, data_1 + i + 1);
result->scores.push_back(vec[0]);
}
const float* data_2 = reinterpret_cast<const float*>(tensors[1].Data());
for (int i = 0; i < tensors[1].shape[0]; i++) {
std::vector<float> vec(data_2 + i, data_2 + i + 1);
result->label_ids.push_back(vec[0]);
}
result->valid.push_back(true); // 0 scores
result->valid.push_back(true); // 1 label_ids
result->valid.push_back(true); // 2 boxes
result->valid.push_back(true); // 3 center
result->valid.push_back(false); // 4 observation_angle
result->valid.push_back(true); // 5 yaw_angle
result->valid.push_back(false); // 6 velocity
return true;
}
} // namespace perception
} // namespace vision
} // namespace fastdeploy
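CaddnPostprocessor thus expects three tensors from the runtime: tensors[0] holds (N, 7) float boxes laid out as bottom center, size and yaw, tensors[1] the labels, and tensors[2] the scores. A small numpy sketch of how one row splits (toy values, not model output):

import numpy as np

boxes = np.array([[5.0, 1.0, -0.5, 1.6, 1.5, 3.9, 0.1]], dtype=np.float32)  # tensors[0]
labels = np.array([0.0], dtype=np.float32)                                  # tensors[1]
scores = np.array([0.9], dtype=np.float32)                                  # tensors[2]
for row, label, score in zip(boxes, labels, scores):
    center, size, yaw = row[0:3], row[3:6], row[6]  # bottom center; w, h, l; yaw angle
    print(center, size, yaw, int(label), float(score))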

View File

@@ -0,0 +1,48 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {
namespace perception {
/*! @brief Postprocessor object for Caddn serials model.
*/
class FASTDEPLOY_DECL CaddnPostprocessor {
public:
/** \brief Create a postprocessor instance for Caddn serials model
*/
CaddnPostprocessor();
/** \brief Process the result of runtime and fill to PerceptionResult structure
*
* \param[in] tensors The inference result from runtime
* \param[in] results The output perception results
* \return true if the postprocess succeeds, otherwise false
*/
bool Run(const std::vector<FDTensor>& tensors,
std::vector<PerceptionResult>* results);
protected:
float conf_threshold_;
};
} // namespace perception
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,112 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/perception/paddle3d/caddn/preprocessor.h"
#include "fastdeploy/function/concat.h"
#include "yaml-cpp/yaml.h"
namespace fastdeploy {
namespace vision {
namespace perception {
CaddnPreprocessor::CaddnPreprocessor(const std::string& config_file) {
config_file_ = config_file;
FDASSERT(BuildPreprocessPipeline(),
"Failed to create Paddle3DDetPreprocessor.");
initialized_ = true;
}
bool CaddnPreprocessor::BuildPreprocessPipeline() {
processors_.clear();
// preprocess
processors_.push_back(std::make_shared<BGR2RGB>());
std::vector<float> alpha = {1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0};
std::vector<float> beta = {0.0, 0.0, 0.0};
processors_.push_back(std::make_shared<Convert>(alpha, beta));
processors_.push_back(std::make_shared<Cast>("float"));
processors_.push_back(std::make_shared<HWC2CHW>());
// Fusion will improve performance
FuseTransforms(&processors_);
return true;
}
bool CaddnPreprocessor::Apply(FDMatBatch* image_batch,
std::vector<float>& input_cam_data,
std::vector<float>& input_lidar_data,
std::vector<FDTensor>* outputs) {
if (image_batch->mats->empty()) {
FDERROR << "The size of input images should be greater than 0."
<< std::endl;
return false;
}
if (!initialized_) {
FDERROR << "The preprocessor is not initialized." << std::endl;
return false;
}
// There are 3 outputs, image, cam_data, lidar_data
outputs->resize(3);
int batch = static_cast<int>(image_batch->mats->size());
// Allocate memory for cam_data
(*outputs)[1].Resize({batch, 3, 4}, FDDataType::FP32);
// Allocate memory for lidar_data
(*outputs)[2].Resize({batch, 4, 4}, FDDataType::FP32);
auto* cam_data_ptr = reinterpret_cast<float*>((*outputs)[1].MutableData());
auto* lidar_data_ptr = reinterpret_cast<float*>((*outputs)[2].MutableData());
for (size_t i = 0; i < image_batch->mats->size(); ++i) {
FDMat* mat = &(image_batch->mats->at(i));
for (size_t j = 0; j < processors_.size(); ++j) {
if (!(*(processors_[j].get()))(mat)) {
FDERROR << "Failed to processs image:" << i << " in "
<< processors_[j]->Name() << "." << std::endl;
return false;
}
}
memcpy(cam_data_ptr + i * 12, input_cam_data.data(), 12 * sizeof(float));
memcpy(lidar_data_ptr + i * 16, input_lidar_data.data(),
16 * sizeof(float));
}
FDTensor* tensor = image_batch->Tensor();
(*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(),
tensor->Data(), tensor->device,
tensor->device_id);
return true;
}
bool CaddnPreprocessor::Run(std::vector<FDMat>* images,
std::vector<float>& input_cam_data,
std::vector<float>& input_lidar_data,
std::vector<FDTensor>* outputs) {
FDMatBatch image_batch(images);
PreApply(&image_batch);
bool ret = Apply(&image_batch, input_cam_data, input_lidar_data, outputs);
PostApply();
return ret;
}
} // namespace perception
} // namespace vision
} // namespace fastdeploy
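In numpy terms, the preprocessing above amounts to BGR-to-RGB, scaling by 1/255, a float cast and HWC-to-CHW, plus one flattened 3x4 camera matrix and one 4x4 lidar matrix per image. A rough equivalent for reference only, assuming a single-image batch (the deployed path runs through the fused FDMat processors):

import numpy as np

def caddn_preprocess(bgr_image, cam_data, lidar_data):
    rgb = bgr_image[:, :, ::-1].astype(np.float32)  # BGR2RGB
    rgb *= 1.0 / 255.0                              # Convert(alpha=1/255, beta=0)
    chw = np.transpose(rgb, (2, 0, 1))              # HWC2CHW
    images = chw[np.newaxis]                        # batch of 1
    cam = np.asarray(cam_data, np.float32).reshape(1, 3, 4)
    lidar = np.asarray(lidar_data, np.float32).reshape(1, 4, 4)
    return images, cam, lidar

im = np.zeros((375, 1242, 3), dtype=np.uint8)  # dummy KITTI-sized frame
outs = caddn_preprocess(im, [0.0] * 12, [0.0] * 16)
print([o.shape for o in outs])  # [(1, 3, 375, 1242), (1, 3, 4), (1, 4, 4)]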

View File

@@ -0,0 +1,69 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/manager.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {
namespace perception {
/*! @brief Preprocessor object for Caddn serials model.
*/
class FASTDEPLOY_DECL CaddnPreprocessor : public ProcessorManager {
public:
CaddnPreprocessor() = default;
/** \brief Create a preprocessor instance for Caddn model
*
* \param[in] config_file Path of configuration file for deployment, e.g Caddn/infer_cfg.yml
*/
explicit CaddnPreprocessor(const std::string& config_file);
bool Run(std::vector<FDMat>* images,
std::vector<float>& input_cam_data,
std::vector<float>& input_lidar_data,
std::vector<FDTensor>* outputs);
/** \brief Process the input image and prepare input tensors for runtime
*
* \param[in] image_batch The batch of input images to be processed
* \param[in] outputs The output tensors which will be fed into runtime
* \return true if the preprocess succeeds, otherwise false
*/
bool Apply(FDMatBatch* image_batch, std::vector<FDTensor>* outputs) {
FDERROR << "CaddnPreprocessor requires camera and lidar data as inputs"
<< std::endl;
return false;
};
bool Apply(FDMatBatch* image_batch,
std::vector<float>& input_cam_data,
std::vector<float>& input_lidar_data,
std::vector<FDTensor>* outputs);
protected:
bool BuildPreprocessPipeline();
std::vector<std::shared_ptr<Processor>> processors_;
bool disable_permute_ = false;
bool initialized_ = false;
std::string config_file_;
};
} // namespace perception
} // namespace vision
} // namespace fastdeploy

View File

@@ -24,8 +24,7 @@ CenterpointPreprocessor::CenterpointPreprocessor(
bool CenterpointPreprocessor::ReadPoint(const std::string &file_path,
const int64_t num_point_dim,
-std::vector<float> &data,
-int64_t *num_points) {
+std::vector<float> &data, int64_t *num_points) {
std::ifstream file_in(file_path, std::ios::in | std::ios::binary);
if (num_point_dim < 4) {
FDERROR << "Point dimension must not be less than 4, but received "

View File

@@ -19,6 +19,7 @@ namespace fastdeploy {
void BindSmoke(pybind11::module& m);
void BindPetr(pybind11::module& m);
void BindCenterpoint(pybind11::module& m);
void BindCaddn(pybind11::module& m);
void BindPerception(pybind11::module& m) {
auto perception_module =
@@ -26,5 +27,6 @@ void BindPerception(pybind11::module& m) {
BindSmoke(perception_module);
BindPetr(perception_module);
BindCenterpoint(perception_module);
BindCaddn(perception_module);
}
} // namespace fastdeploy

View File

@@ -16,3 +16,4 @@ from __future__ import absolute_import
from .paddle3d.smoke import *
from .paddle3d.petr import *
from .paddle3d.centerpoint import *
from .paddle3d.caddn import *

View File

@@ -0,0 +1,108 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import logging
from .... import FastDeployModel, ModelFormat
from .... import c_lib_wrap as C
class CaddnPreprocessor:
def __init__(self, config_file):
"""Create a preprocessor for Caddn
"""
self._preprocessor = C.vision.perception.CaddnPreprocessor(config_file)
def run(self, input_ims, cam_data, lidar_data):
"""Preprocess input images for Caddn
:param: input_ims: (list of numpy.ndarray)The input images
:param: cam_data: (list of float)The flattened 3x4 camera-to-image transform
:param: lidar_data: (list of float)The flattened 4x4 lidar-to-camera transform
:return: list of FDTensor
"""
return self._preprocessor.run(input_ims, cam_data, lidar_data)
class CaddnPostprocessor:
def __init__(self):
"""Create a postprocessor for Caddn
"""
self._postprocessor = C.vision.perception.CaddnPostprocessor()
def run(self, runtime_results):
"""Postprocess the runtime results for Caddn
:param: runtime_results: (list of FDTensor)The output FDTensor results from runtime
:return: list of PerceptionResult(If the runtime_results is predict by batched samples, the length of this list equals to the batch size)
"""
return self._postprocessor.run(runtime_results)
class Caddn(FastDeployModel):
def __init__(self,
model_file,
params_file,
config_file,
runtime_option=None,
model_format=ModelFormat.PADDLE):
"""Load a Caddn model exported by Caddn.
:param model_file: (str)Path of model file, e.g ./Caddn.pdmodel
:param params_file: (str)Path of parameters file, e.g ./Caddn.pdiparams
:param config_file: (str)Path of config file, e.g ./infer_cfg.yaml
:param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
:param model_format: (fastdeploy.ModelFormat)Model format of the loaded model
"""
super(Caddn, self).__init__(runtime_option)
self._model = C.vision.perception.Caddn(
model_file, params_file, config_file, self._runtime_option,
model_format)
assert self.initialized, "Caddn initialize failed."
def predict(self, input_image, cam_data, lidar_data):
"""Detect an input image
:param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
:param: cam_data: (list)The input camera data
:param: lidar_data: (list)The input lidar data
:return: PerceptionResult
"""
return self._model.predict(input_image, cam_data, lidar_data)
def batch_predict(self, images, cam_data, lidar_data):
"""Classify a batch of input image
:param im: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
:param: cam_data: (list)The input camera data
:param: lidar_data: (list)The input lidar data
:return list of PerceptionResult
"""
return self._model.batch_predict(images, cam_data, lidar_data)
@property
def preprocessor(self):
"""Get CaddnPreprocessor object of the loaded model
:return CaddnPreprocessor
"""
return self._model.preprocessor
@property
def postprocessor(self):
"""Get CaddnPostprocessor object of the loaded model
:return CaddnPostprocessor
"""
return self._model.postprocessor
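Putting the new Python API together, a hedged end-to-end sketch; the model paths, image and transform values below are placeholders, not shipped with this commit:

import cv2
import fastdeploy as fd

option = fd.RuntimeOption()
option.use_gpu()
option.use_paddle_infer_backend()  # Caddn's valid GPU backend is Paddle Inference

model = fd.vision.perception.Caddn(
    "caddn/model.pdmodel", "caddn/model.pdiparams", "caddn/infer_cfg.yml",
    runtime_option=option)

im = cv2.imread("kitti_sample.png")
cam_data = [0.0] * 12    # flattened 3x4 trans_cam_to_img
lidar_data = [0.0] * 16  # flattened 4x4 trans_lidar_to_cam
result = model.predict(im, cam_data, lidar_data)
print(result)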