[Model] Support Paddle3D PETR v2 model (#1863)

* Support PETR v2

* make PETR v2 precision match the original repo

* delete extra func

* address review comments

* delete visualize

* Update README_CN.md

* Update README.md

* Update README_CN.md

* fix build problem

* delete external variable and function

---------

Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com>
CoolCola authored 2023-05-19 10:45:36 +08:00; committed by GitHub
parent c8ff8b63e8
commit e3b285c762
20 changed files with 1181 additions and 0 deletions

View File

@@ -0,0 +1,11 @@
English | [简体中文](README_CN.md)
# Petr Ready-to-deploy Model
The Petr deployment module implements the PETR model from Paddle3D. For more detailed information about the model, please refer to the [PETR introduction](https://github.com/PaddlePaddle/Paddle3D/tree/develop/docs/models/petr).
## Detailed Deployment Documents
- [Python Deployment](python)
- [C++ Deployment](cpp)

View File

@@ -0,0 +1,12 @@
[English](README.md) | 简体中文
# Petr Ready-to-deploy Model
The Petr deployment module implements the PETR model from Paddle3D. For more detailed information about the model, please refer to the [PETR introduction](https://github.com/PaddlePaddle/Paddle3D/tree/develop/docs/models/petr).
## Detailed Deployment Documents
- [Python Deployment](python)
- [C++ Deployment](cpp)

View File

@@ -0,0 +1,14 @@
CMAKE_MINIMUM_REQUIRED(VERSION 3.10)
PROJECT(infer_demo C CXX)
# Specify the path of the downloaded and extracted FastDeploy SDK
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
# Add the FastDeploy include directories
include_directories(${FASTDEPLOY_INCS})
add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
# Link the FastDeploy libraries
target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})

View File

@@ -0,0 +1,76 @@
English | [简体中文](README_CN.md)
# Petr C++ Deployment Example
This directory provides `infer.cc`, an example that quickly completes the deployment of PETR on CPU/GPU.
Before deployment, confirm the following two steps:
- 1. The hardware and software environment meets the requirements; refer to [FastDeploy environment requirements](../../../../../docs/en/build_and_install/download_prebuilt_libraries.md)
- 2. Download the precompiled deployment library and sample code for your development environment; refer to [FastDeploy prebuilt library](../../../../../docs/en/build_and_install/download_prebuilt_libraries.md)
Taking CPU inference on Linux as an example, execute the following commands in this directory to complete the compilation test. To support this model, FastDeploy version 1.0.6 or higher is required (x.x.x >= 1.0.6).
```bash
mkdir build
cd build
# Download the FastDeploy precompiled library. Users can choose an appropriate version from the `FastDeploy precompiled library` mentioned above
wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-x.x.x.tgz
tar xvf fastdeploy-linux-x64-x.x.x.tgz
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-x.x.x
make -j
wget https://bj.bcebos.com/fastdeploy/models/petr.tar.gz
tar -xf petr.tar.gz
wget https://bj.bcebos.com/fastdeploy/models/petr_test.png
# CPU
./infer_demo petr petr_test.png 0
# GPU
./infer_demo petr petr_test.png 1
```
The above commands are only applicable to Linux or macOS. For the usage of the SDK under Windows, please refer to:
- [How to use FastDeploy C++ SDK in Windows](../../../../../docs/en/faq/use_sdk_on_windows.md)
## Petr C++ Interface
### Petr Class
```c++
fastdeploy::vision::perception::Petr(
const string& model_file,
const string& params_file,
const string& config_file,
const RuntimeOption& runtime_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE)
```
Petr model loading and initialization.
**Parameters**
> * **model_file**(str): Model file path
> * **params_file**(str): Parameter file path
> * **config_file**(str): Configuration file path
> * **runtime_option**(RuntimeOption): Backend inference configuration; if not set, the default configuration is used
> * **model_format**(ModelFormat): Model format; the default is the Paddle format
#### Predict Function
> ```c++
> Petr::Predict(cv::Mat* im, PerceptionResult* result)
> ```
>
> Model prediction interface: takes an input image and directly outputs the detection result.
>
> **parameters**
>
> > * **im**: Input image; note that it must be in HWC, BGR format
> > * **result**: Detection result, including the detection boxes and the confidence of each box. For a description of PerceptionResult, refer to [vision model prediction results](../../../../../docs/api/vision_results/)
- [Model Introduction](../../)
- [Python deployment](../python)
- [Vision Model Prediction Results](../../../../../docs/api/vision_results/)
- [How to switch model inference backend engine](../../../../../docs/en/faq/how_to_change_backend.md)
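
For reference, below is a minimal end-to-end sketch of the interface above. The model directory layout and the 12-image input (two frames, as read by `infer.cc`) follow the demo; the paths are placeholders to adapt to your setup.
```c++
#include <iostream>
#include <string>
#include <vector>

#include "fastdeploy/vision.h"

int main() {
  // Paths are assumptions following the demo layout; adjust as needed.
  auto model = fastdeploy::vision::perception::Petr(
      "petr/petrv2_inference.pdmodel", "petr/petrv2_inference.pdiparams",
      "petr/infer_cfg.yml");
  if (!model.Initialized()) {
    std::cerr << "Failed to initialize the model." << std::endl;
    return -1;
  }
  // PETR v2 consumes 12 images per sample, as in infer.cc.
  std::vector<cv::Mat> batch;
  for (int i = 0; i < 12; ++i) {
    batch.push_back(cv::imread("images/image" + std::to_string(i) + ".png"));
  }
  std::vector<fastdeploy::vision::PerceptionResult> results;
  if (!model.BatchPredict(batch, &results)) {
    std::cerr << "Failed to predict." << std::endl;
    return -1;
  }
  std::cout << results[0].Str() << std::endl;
  return 0;
}
```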

View File

@@ -0,0 +1,76 @@
[English](README.md) | 简体中文
# Petr C++ Deployment Example
This directory provides `infer.cc`, an example that quickly completes the deployment of PETR on CPU/GPU.
Before deployment, confirm the following two steps:
- 1. The hardware and software environment meets the requirements; refer to [FastDeploy environment requirements](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
- 2. Download the precompiled deployment library and sample code for your development environment; refer to [FastDeploy prebuilt library](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
Taking CPU inference on Linux as an example, execute the following commands in this directory to complete the compilation test. To support this model, FastDeploy version 1.0.6 or higher is required (x.x.x >= 1.0.6).
```bash
mkdir build
cd build
# Download the FastDeploy precompiled library. Users can choose an appropriate version from the `FastDeploy precompiled library` mentioned above
wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-x.x.x.tgz
tar xvf fastdeploy-linux-x64-x.x.x.tgz
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/fastdeploy-linux-x64-x.x.x
make -j
wget https://bj.bcebos.com/fastdeploy/models/petr.tar.gz
tar -xf petr.tar.gz
wget https://bj.bcebos.com/fastdeploy/models/petr_test.png
# CPU inference
./infer_demo petr petr_test.png 0
# GPU inference
./infer_demo petr petr_test.png 1
```
The above commands are only applicable to Linux or macOS. For the usage of the SDK under Windows, please refer to:
- [How to use FastDeploy C++ SDK in Windows](../../../../../docs/cn/faq/use_sdk_on_windows.md)
## Petr C++ Interface
### Petr Class
```c++
fastdeploy::vision::perception::Petr(
    const string& model_file,
    const string& params_file,
    const string& config_file,
    const RuntimeOption& runtime_option = RuntimeOption(),
    const ModelFormat& model_format = ModelFormat::PADDLE)
```
Petr model loading and initialization.
**Parameters**
> * **model_file**(str): Model file path
> * **params_file**(str): Parameter file path
> * **config_file**(str): Configuration file path
> * **runtime_option**(RuntimeOption): Backend inference configuration; if not set, the default configuration is used
> * **model_format**(ModelFormat): Model format; the default is the Paddle format
#### Predict Function
> ```c++
> Petr::Predict(cv::Mat* im, PerceptionResult* result)
> ```
>
> Model prediction interface: takes an input image and directly outputs the detection result.
>
> **Parameters**
>
> > * **im**: Input image; note that it must be in HWC, BGR format
> > * **result**: Detection result, including the detection boxes and the confidence of each box. For a description of PerceptionResult, refer to [vision model prediction results](../../../../../docs/api/vision_results/)
- [Model Introduction](../../)
- [Python Deployment](../python)
- [Vision Model Prediction Results](../../../../../docs/api/vision_results/)
- [How to switch model inference backend engine](../../../../../docs/cn/faq/how_to_change_backend.md)
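
As a companion to the backend-switch FAQ linked above, here is a small sketch of configuring the runtime before constructing the model. The choices mirror the ones exercised by `infer.cc`; they are illustrative, not exhaustive.
```c++
#include "fastdeploy/vision.h"

// A minimal sketch: pick CPU or GPU, then pin the Paddle Inference backend,
// which is the backend this model currently validates (see petr.cc).
fastdeploy::RuntimeOption MakePetrOption(bool use_gpu) {
  fastdeploy::RuntimeOption option;
  if (use_gpu) {
    option.UseGpu();  // CUDA device 0 by default
  } else {
    option.UseCpu();
  }
  option.UsePaddleBackend();
  return option;
}
```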

View File

@@ -0,0 +1,84 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
#ifdef WIN32
const char sep = '\\';
#else
const char sep = '/';
#endif
void InitAndInfer(const std::string& model_dir, const std::string& images_dir,
const fastdeploy::RuntimeOption& option) {
auto model_file = model_dir + sep + "petrv2_inference.pdmodel";
auto params_file = model_dir + sep + "petrv2_inference.pdiparams";
auto config_file = model_dir + sep + "infer_cfg.yml";
fastdeploy::vision::EnableFlyCV();
auto model = fastdeploy::vision::perception::Petr(
model_file, params_file, config_file, option,
fastdeploy::ModelFormat::PADDLE);
assert(model.Initialized());
std::vector<cv::Mat> im_batch;
for (int i = 0; i < 12; i++) {
auto image_file = images_dir + sep + "image" + std::to_string(i) + ".png";
auto im = cv::imread(image_file);
im_batch.emplace_back(im);
}
std::vector<fastdeploy::vision::PerceptionResult> res;
if (!model.BatchPredict(im_batch, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
std::cout << res[0].Str() << std::endl;
}
int main(int argc, char* argv[]) {
if (argc < 4) {
std::cout << "Usage: infer_demo path/to/paddle_model"
"path/to/image "
"run_option, "
"e.g ./infer_demo ./petr ./00000.png 0"
<< std::endl;
std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
"with gpu; 2: run with paddle-trt"
<< std::endl;
return -1;
}
fastdeploy::RuntimeOption option;
if (std::atoi(argv[3]) == 0) {
option.UseCpu();
} else if (std::atoi(argv[3]) == 1) {
option.UseGpu();
} else if (std::atoi(argv[3]) == 2) {
option.UseGpu();
option.UseTrtBackend();
option.EnablePaddleToTrt();
option.SetTrtInputShape("images", {1, 3, 384, 1280});
option.SetTrtInputShape("down_ratios", {1, 2});
option.SetTrtInputShape("trans_cam_to_img", {1, 3, 3});
option.SetTrtInputData("trans_cam_to_img",
{721.53771973, 0., 609.55932617, 0., 721.53771973,
172.85400391, 0, 0, 1});
option.EnablePaddleTrtCollectShape();
}
option.UsePaddleBackend();
  std::string model_dir = argv[1];
  std::string images_dir = argv[2];
  InitAndInfer(model_dir, images_dir, option);
return 0;
}

View File

@@ -0,0 +1,63 @@
English | [简体中文](README_CN.md)
# Petr Python Deployment Example
Before deployment, confirm the following two steps:
- 1. The hardware and software environment meets the requirements; refer to [FastDeploy environment requirements](../../../../../docs/en/build_and_install/download_prebuilt_libraries.md)
- 2. Install the FastDeploy Python whl package; refer to [FastDeploy Python Installation](../../../../../docs/en/build_and_install/download_prebuilt_libraries.md)
This directory provides `infer.py`, an example that quickly completes the deployment of PETR on CPU/GPU. Execute the following script to complete the deployment
```bash
# Download the deployment sample code
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/examples/vision/perception/paddle3d/petr/python
wget https://bj.bcebos.com/fastdeploy/models/petr.tar.gz
tar -xf petr.tar.gz
wget https://bj.bcebos.com/fastdeploy/models/petr_test.png
# CPU inference
python infer.py --model petr --image petr_test.png --device cpu
# GPU inference
python infer.py --model petr --image petr_test.png --device gpu
```
## Petr Python Interface
```python
fastdeploy.vision.perception.Petr(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
```
Petr model loading and initialization.
**Parameters**
> * **model_file**(str): Model file path
> * **params_file**(str): Parameter file path
> * **config_file**(str): Configuration file path
> * **runtime_option**(RuntimeOption): Backend inference configuration; the default is None, i.e. the default configuration is used
> * **model_format**(ModelFormat): Model format; the default is the Paddle format
### predict Function
> ```python
> Petr.predict(image_data)
> ```
>
> Model prediction interface: takes an input image and directly outputs the detection result.
>
> **parameters**
>
> > * **image_data**(np.ndarray): Input data; note that it must be in HWC, BGR format
> **Return**
>
> > Returns a `fastdeploy.vision.PerceptionResult` structure. For a description of the structure, refer to [Vision Model Prediction Results](../../../../../docs/api/vision_results/)
## Other Documents
- [Petr Model Introduction](..)
- [Petr C++ Deployment](../cpp)
- [Description of model prediction results](../../../../../docs/api/vision_results/)
- [How to switch model inference backend engine](../../../../../docs/en/faq/how_to_change_backend.md)

View File

@@ -0,0 +1,65 @@
[English](README.md) | 简体中文
# Petr Python Deployment Example
Before deployment, confirm the following two steps:
- 1. The hardware and software environment meets the requirements; refer to [FastDeploy environment requirements](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
- 2. Install the FastDeploy Python whl package; refer to [FastDeploy Python Installation](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
This directory provides `infer.py`, an example that quickly completes the deployment of PETR on CPU/GPU. Execute the following script to complete the deployment
```bash
# Download the deployment sample code
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/examples/vision/perception/paddle3d/petr/python
wget https://bj.bcebos.com/fastdeploy/models/petr.tar.gz
tar -xf petr.tar.gz
wget https://bj.bcebos.com/fastdeploy/models/petr_test.png
# CPU inference
python infer.py --model petr --image petr_test.png --device cpu
# GPU inference
python infer.py --model petr --image petr_test.png --device gpu
```
## Petr Python Interface
```python
fastdeploy.vision.perception.Petr(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
```
Petr model loading and initialization.
**Parameters**
> * **model_file**(str): Model file path
> * **params_file**(str): Parameter file path
> * **config_file**(str): Configuration file path
> * **runtime_option**(RuntimeOption): Backend inference configuration; the default is None, i.e. the default configuration is used
> * **model_format**(ModelFormat): Model format; the default is the Paddle format
### predict Function
> ```python
> Petr.predict(image_data)
> ```
>
> Model prediction interface: takes an input image and directly outputs the detection result.
>
> **Parameters**
>
> > * **image_data**(np.ndarray): Input data; note that it must be in HWC, BGR format
> **Return**
>
> > Returns a `fastdeploy.vision.PerceptionResult` structure. For a description of the structure, refer to [Vision Model Prediction Results](../../../../../docs/api/vision_results/)
## Other Documents
- [Petr Model Introduction](..)
- [Petr C++ Deployment](../cpp)
- [Description of model prediction results](../../../../../docs/api/vision_results/)
- [How to switch model inference backend engine](../../../../../docs/cn/faq/how_to_change_backend.md)

View File

@@ -0,0 +1,45 @@
import fastdeploy as fd
import cv2
import os
from fastdeploy import ModelFormat
def parse_arguments():
import argparse
import ast
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", required=True, help="Path of petr paddle model.")
parser.add_argument(
"--image", required=True, help="Path of test image file.")
parser.add_argument(
"--device",
type=str,
default='cpu',
help="Type of inference device, support 'cpu' or 'gpu'.")
return parser.parse_args()
def build_option(args):
option = fd.RuntimeOption()
if args.device.lower() == "gpu":
option.use_gpu(0)
if args.device.lower() == "cpu":
option.use_cpu()
return option
args = parse_arguments()
model_file = os.path.join(args.model, "petrv2_inference.pdmodel")
params_file = os.path.join(args.model, "petrv2_inference.pdiparams")
config_file = os.path.join(args.model, "infer_cfg.yml")
# Configure the runtime and load the model
runtime_option = build_option(args)
model = fd.vision.perception.Petr(
model_file, params_file, config_file, runtime_option=runtime_option)
# Predict the image and print the detection result
im = cv2.imread(args.image)
result = model.predict(im)
print(result)

View File

@@ -35,6 +35,7 @@
#include "fastdeploy/vision/detection/contrib/yolox.h"
#include "fastdeploy/vision/detection/contrib/rknpu2/model.h"
#include "fastdeploy/vision/perception/paddle3d/smoke/smoke.h"
#include "fastdeploy/vision/perception/paddle3d/petr/petr.h"
#include "fastdeploy/vision/detection/ppdet/model.h"
#include "fastdeploy/vision/facealign/contrib/face_landmark_1000.h"
#include "fastdeploy/vision/facealign/contrib/pfld.h"

View File

@@ -0,0 +1,83 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/perception/paddle3d/petr/petr.h"
namespace fastdeploy {
namespace vision {
namespace perception {
Petr::Petr(const std::string& model_file, const std::string& params_file,
const std::string& config_file, const RuntimeOption& custom_option,
const ModelFormat& model_format)
: preprocessor_(config_file) {
valid_cpu_backends = {Backend::PDINFER};
valid_gpu_backends = {Backend::PDINFER};
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
runtime_option.paddle_infer_option.enable_mkldnn = false;
initialized = Initialize();
}
bool Petr::Initialize() {
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
return true;
}
bool Petr::Predict(const cv::Mat& im, PerceptionResult* result) {
std::vector<PerceptionResult> results;
if (!BatchPredict({im}, &results)) {
return false;
}
if (results.size()) {
*result = std::move(results[0]);
}
return true;
}
bool Petr::BatchPredict(const std::vector<cv::Mat>& images,
std::vector<PerceptionResult>* results) {
std::vector<FDMat> fd_images = WrapMat(images);
if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
FDERROR << "Failed to preprocess the input image." << std::endl;
return false;
}
  // The three inputs are: images, k_data, timestamps (see PetrPreprocessor).
  reused_input_tensors_[0].name = InputInfoOfRuntime(0).name;
  reused_input_tensors_[1].name = InputInfoOfRuntime(1).name;
  reused_input_tensors_[2].name = InputInfoOfRuntime(2).name;
if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
FDERROR << "Failed to inference by runtime." << std::endl;
return false;
}
if (!postprocessor_.Run(reused_output_tensors_, results)) {
FDERROR << "Failed to postprocess the inference results by runtime."
<< std::endl;
return false;
}
return true;
}
} // namespace perception
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,78 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. //NOLINT
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/perception/paddle3d/petr/preprocessor.h"
#include "fastdeploy/vision/perception/paddle3d/petr/postprocessor.h"
namespace fastdeploy {
namespace vision {
namespace perception {
/*! @brief PETR model object, used to load a PETR model exported by Paddle3D.
 */
class FASTDEPLOY_DECL Petr : public FastDeployModel {
public:
/** \brief Set path of model file and the configuration of runtime.
*
   * \param[in] model_file Path of model file, e.g. petr/petrv2_inference.pdmodel
   * \param[in] params_file Path of parameter file, e.g. petr/petrv2_inference.pdiparams; if the model format is ONNX, this parameter will be ignored
   * \param[in] config_file Path of configuration file for deployment, e.g. petr/infer_cfg.yml
   * \param[in] custom_option RuntimeOption for inference; the default will use CPU and choose the backend defined in "valid_cpu_backends"
   * \param[in] model_format Model format of the loaded model; default is the Paddle format
*/
Petr(const std::string& model_file, const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE);
std::string ModelName() const { return "Paddle3D/petr"; }
/** \brief Predict the perception result for an input image
*
   * \param[in] img The input image data, which comes from cv::imread(), a 3-D array with layout HWC, BGR format
   * \param[in] result The output perception result will be written to this structure
   * \return true if the prediction succeeds, otherwise false
*/
virtual bool Predict(const cv::Mat& img, PerceptionResult* result);
/** \brief Predict the perception results for a batch of input images
*
   * \param[in] imgs The input image list, each element comes from cv::imread()
   * \param[in] results The output perception result list
   * \return true if the prediction succeeds, otherwise false
*/
virtual bool BatchPredict(const std::vector<cv::Mat>& imgs,
std::vector<PerceptionResult>* results);
/// Get preprocessor reference of Petr
virtual PetrPreprocessor& GetPreprocessor() {
return preprocessor_;
}
/// Get postprocessor reference of Petr
virtual PetrPostprocessor& GetPostprocessor() {
return postprocessor_;
}
protected:
bool Initialize();
PetrPreprocessor preprocessor_;
PetrPostprocessor postprocessor_;
bool initialized_ = false;
};
} // namespace perception
} // namespace vision
} // namespace fastdeploy
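
The `GetPreprocessor()`/`GetPostprocessor()` accessors make it possible to drive the two stages separately, for example to inspect the tensors that will be fed to the runtime. A sketch under that assumption (it mirrors what `petr.cc` does internally; the helper name is illustrative):
```c++
#include "fastdeploy/vision.h"

void InspectPetrInputs(fastdeploy::vision::perception::Petr& model,
                       std::vector<cv::Mat>& images) {
  // Wrap the OpenCV images and run only the preprocessing stage.
  std::vector<fastdeploy::vision::FDMat> mats =
      fastdeploy::vision::WrapMat(images);
  std::vector<fastdeploy::FDTensor> tensors;
  if (model.GetPreprocessor().Run(&mats, &tensors)) {
    // Expect three tensors: images, k_data, timestamps.
    for (auto& tensor : tensors) {
      tensor.PrintInfo();
    }
  }
}
```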

View File

@@ -0,0 +1,92 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindPetr(pybind11::module& m) {
pybind11::class_<vision::perception::PetrPreprocessor,
vision::ProcessorManager>(m, "PetrPreprocessor")
.def(pybind11::init<std::string>())
.def("run", [](vision::perception::PetrPreprocessor& self,
std::vector<pybind11::array>& im_list) {
std::vector<vision::FDMat> images;
for (size_t i = 0; i < im_list.size(); ++i) {
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
}
std::vector<FDTensor> outputs;
if (!self.Run(&images, &outputs)) {
throw std::runtime_error(
"Failed to preprocess the input data in PetrPreprocessor.");
}
for (size_t i = 0; i < outputs.size(); ++i) {
outputs[i].StopSharing();
}
return outputs;
});
pybind11::class_<vision::perception::PetrPostprocessor>(m,
"PetrPostprocessor")
.def(pybind11::init<>())
.def("run",
[](vision::perception::PetrPostprocessor& self,
std::vector<FDTensor>& inputs) {
std::vector<vision::PerceptionResult> results;
if (!self.Run(inputs, &results)) {
throw std::runtime_error(
"Failed to postprocess the runtime result in "
"PetrPostprocessor.");
}
return results;
})
.def("run", [](vision::perception::PetrPostprocessor& self,
std::vector<pybind11::array>& input_array) {
std::vector<vision::PerceptionResult> results;
std::vector<FDTensor> inputs;
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
if (!self.Run(inputs, &results)) {
throw std::runtime_error(
"Failed to postprocess the runtime result in "
"PetrPostprocessor.");
}
return results;
});
pybind11::class_<vision::perception::Petr, FastDeployModel>(m, "Petr")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>())
.def("predict",
[](vision::perception::Petr& self, pybind11::array& data) {
auto mat = PyArrayToCvMat(data);
vision::PerceptionResult res;
self.Predict(mat, &res);
return res;
})
.def("batch_predict",
[](vision::perception::Petr& self,
std::vector<pybind11::array>& data) {
std::vector<cv::Mat> images;
for (size_t i = 0; i < data.size(); ++i) {
images.push_back(PyArrayToCvMat(data[i]));
}
std::vector<vision::PerceptionResult> results;
self.BatchPredict(images, &results);
return results;
})
.def_property_readonly("preprocessor",
&vision::perception::Petr::GetPreprocessor)
.def_property_readonly("postprocessor",
&vision::perception::Petr::GetPostprocessor);
}
} // namespace fastdeploy

View File

@@ -0,0 +1,63 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/perception/paddle3d/petr/postprocessor.h"
#include "fastdeploy/vision/utils/utils.h"
namespace fastdeploy {
namespace vision {
namespace perception {
PetrPostprocessor::PetrPostprocessor() {}
bool PetrPostprocessor::Run(const std::vector<FDTensor>& tensors,
std::vector<PerceptionResult>* results) {
results->resize(1);
(*results)[0].Clear();
(*results)[0].Reserve(tensors[0].shape[0]);
if (tensors[0].dtype != FDDataType::FP32) {
FDERROR << "Only support post process with float32 data." << std::endl;
return false;
}
const float* data_0 = reinterpret_cast<const float*>(tensors[0].Data());
auto result = &(*results)[0];
  // Each detection is encoded as 9 consecutive floats:
  //   items 1 ~ 3 : box3d w, h, l
  //   items 4 ~ 6 : box3d bottom center x, y, z
  //   item  7     : box3d yaw angle
  //   items 8 ~ 9 : velocity x, y
  for (int i = 0; i < tensors[0].shape[0] * tensors[0].shape[1]; i += 9) {
    std::vector<float> vec(data_0 + i, data_0 + i + 9);
    // The first four entries (the 2D box) are not produced by PETR; keep 0.
    result->boxes.emplace_back(
        std::array<float, 7>{0, 0, 0, 0, vec[0], vec[1], vec[2]});
    result->center.emplace_back(std::array<float, 3>{vec[3], vec[4], vec[5]});
    result->yaw_angle.push_back(vec[6]);
    // The z component of the velocity is value-initialized to 0.
    result->velocity.push_back(std::array<float, 3>{vec[7], vec[8]});
  }
  const float* data_1 = reinterpret_cast<const float*>(tensors[1].Data());
  for (int i = 0; i < tensors[1].shape[0]; ++i) {
    result->scores.push_back(data_1[i]);
  }
  const long long* data_2 =
      reinterpret_cast<const long long*>(tensors[2].Data());
  for (int i = 0; i < tensors[2].shape[0]; ++i) {
    result->label_ids.push_back(data_2[i]);
  }
return true;
}
} // namespace perception
} // namespace vision
} // namespace fastdeploy
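
To make the field mapping concrete, here is a sketch (not part of this commit) of reading back the fields populated above; the indices follow the 9-float layout documented in `Run()`:
```c++
#include <iostream>
#include "fastdeploy/vision.h"

void PrintPetrResult(const fastdeploy::vision::PerceptionResult& res) {
  for (size_t i = 0; i < res.scores.size(); ++i) {
    std::cout << "label=" << res.label_ids[i] << " score=" << res.scores[i]
              << " whl=(" << res.boxes[i][4] << ", " << res.boxes[i][5]
              << ", " << res.boxes[i][6] << ")"
              << " center=(" << res.center[i][0] << ", " << res.center[i][1]
              << ", " << res.center[i][2] << ")"
              << " yaw=" << res.yaw_angle[i] << std::endl;
  }
}
```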

View File

@@ -0,0 +1,48 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {
namespace perception {
/*! @brief Postprocessor object for the Petr series of models.
 */
class FASTDEPLOY_DECL PetrPostprocessor {
public:
/** \brief Create a postprocessor instance for Petr serials model
*/
PetrPostprocessor();
  /** \brief Process the result of runtime and fill to PerceptionResult structure
   *
   * \param[in] tensors The inference result from runtime
   * \param[in] results The output perception result list
   * \return true if the postprocess succeeds, otherwise false
   */
bool Run(const std::vector<FDTensor>& tensors,
std::vector<PerceptionResult>* results);
protected:
float conf_threshold_;
};
} // namespace perception
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,196 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/perception/paddle3d/petr/preprocessor.h"
#include <iostream>
#include "fastdeploy/function/concat.h"
#include "yaml-cpp/yaml.h"
namespace fastdeploy {
namespace vision {
namespace perception {
PetrPreprocessor::PetrPreprocessor(const std::string& config_file) {
config_file_ = config_file;
  FDASSERT(BuildPreprocessPipelineFromConfig(),
           "Failed to create PetrPreprocessor.");
initialized_ = true;
}
bool PetrPreprocessor::BuildPreprocessPipelineFromConfig() {
processors_.clear();
YAML::Node cfg;
try {
cfg = YAML::LoadFile(config_file_);
} catch (YAML::BadFile& e) {
FDERROR << "Failed to load yaml file " << config_file_
<< ", maybe you should check this file." << std::endl;
return false;
}
// read for preprocess
bool has_permute = false;
for (const auto& op : cfg["Preprocess"]) {
std::string op_name = op["type"].as<std::string>();
if (op_name == "NormalizeImage") {
auto mean = op["mean"].as<std::vector<float>>();
auto std = op["std"].as<std::vector<float>>();
bool is_scale = true;
if (op["is_scale"]) {
is_scale = op["is_scale"].as<bool>();
}
std::string norm_type = "mean_std";
if (op["norm_type"]) {
norm_type = op["norm_type"].as<std::string>();
}
if (norm_type != "mean_std") {
std::fill(mean.begin(), mean.end(), 0.0);
std::fill(std.begin(), std.end(), 1.0);
}
mean_ = mean;
std_ = std;
} else if (op_name == "Resize") {
bool keep_ratio = op["keep_ratio"].as<bool>();
auto target_size = op["target_size"].as<std::vector<int>>();
int interp = op["interp"].as<int>();
FDASSERT(target_size.size() == 2,
"Require size of target_size be 2, but now it's %lu.",
target_size.size());
if (!keep_ratio) {
int width = target_size[0];
int height = target_size[1];
processors_.push_back(
std::make_shared<Resize>(width, height, -1.0, -1.0, interp, false));
} else {
int min_target_size = std::min(target_size[0], target_size[1]);
int max_target_size = std::max(target_size[0], target_size[1]);
std::vector<int> max_size;
if (max_target_size > 0) {
max_size.push_back(max_target_size);
max_size.push_back(max_target_size);
}
processors_.push_back(std::make_shared<ResizeByShort>(
min_target_size, interp, true, max_size));
}
} else if (op_name == "Permute") {
// Do nothing, do permute as the last operation
has_permute = true;
continue;
} else {
FDERROR << "Unexcepted preprocess operator: " << op_name << "."
<< std::endl;
return false;
}
}
if (!disable_permute_) {
if (has_permute) {
// permute = cast<float> + HWC2CHW
processors_.push_back(std::make_shared<Cast>("float"));
processors_.push_back(std::make_shared<HWC2CHW>());
}
}
input_k_data_ = cfg["k_data"].as<std::vector<float>>();
return true;
}
bool PetrPreprocessor::Apply(FDMatBatch* image_batch,
std::vector<FDTensor>* outputs) {
if (image_batch->mats->empty()) {
FDERROR << "The size of input images should be greater than 0."
<< std::endl;
return false;
}
if (!initialized_) {
FDERROR << "The preprocessor is not initialized." << std::endl;
return false;
}
  // There are 3 outputs: images, k_data, timestamps
  outputs->resize(3);
  int batch = static_cast<int>(image_batch->mats->size());
  // Allocate memory for k_data
  (*outputs)[1].Resize({1, batch, 4, 4}, FDDataType::FP32);
  // Allocate memory for timestamps
  (*outputs)[2].Resize({1, batch}, FDDataType::FP32);
  // (*outputs)[0] (the images) is set below via SetExternalData from the
  // batched image tensor, so no allocation is needed here.
  auto* k_data_ptr = reinterpret_cast<float*>((*outputs)[1].MutableData());
  auto* timestamp_ptr = reinterpret_cast<float*>((*outputs)[2].MutableData());
for (size_t i = 0; i < image_batch->mats->size(); ++i) {
FDMat* mat = &(image_batch->mats->at(i));
for (size_t j = 0; j < processors_.size(); ++j) {
if (!(*(processors_[j].get()))(mat)) {
FDERROR << "Failed to processs image:" << i << " in "
<< processors_[j]->Name() << "." << std::endl;
return false;
}
if (processors_[j]->Name() == "Resize") {
// crop and normalize after Resize
auto img = *(mat->GetOpenCVMat());
cv::Mat crop_img = img(cv::Range(130, 450), cv::Range(0, 800));
Normalize(&crop_img, mean_, std_, scale_);
FDMat fd_mat = WrapMat(crop_img);
image_batch->mats->at(i) = fd_mat;
}
}
}
  // The config provides k_data (one 4x4 matrix per camera) for a single
  // frame; PETR v2 takes two frames, so duplicate the first half for the
  // second half of the batch. Work on a copy so that repeated calls to
  // Apply() do not keep growing the member vector.
  std::vector<float> k_data(input_k_data_);
  k_data.reserve(batch * 4 * 4);
  for (int i = 0; i < batch / 2 * 4 * 4; ++i) {
    k_data.emplace_back(k_data[i]);
  }
  memcpy(k_data_ptr, k_data.data(), batch * 16 * sizeof(float));
  // The first half of the batch gets timestamp 0, the second half 1.
  std::vector<float> timestamp(batch, 0.0f);
  for (int i = batch / 2; i < batch; ++i) {
    timestamp[i] = 1.0f;
  }
  memcpy(timestamp_ptr, timestamp.data(), batch * sizeof(float));
FDTensor* tensor = image_batch->Tensor();
(*outputs)[0].SetExternalData(tensor->Shape(), tensor->Dtype(),
tensor->Data(), tensor->device,
tensor->device_id);
return true;
}
void PetrPreprocessor::Normalize(cv::Mat* im, const std::vector<float>& mean,
                                 const std::vector<float>& std, float scale) {
  if (scale != 0.0f) {
(*im).convertTo(*im, CV_32FC3, scale);
}
for (int h = 0; h < im->rows; h++) {
for (int w = 0; w < im->cols; w++) {
im->at<cv::Vec3f>(h, w)[0] =
(im->at<cv::Vec3f>(h, w)[0] - mean[0]) / std[0];
im->at<cv::Vec3f>(h, w)[1] =
(im->at<cv::Vec3f>(h, w)[1] - mean[1]) / std[1];
im->at<cv::Vec3f>(h, w)[2] =
(im->at<cv::Vec3f>(h, w)[2] - mean[2]) / std[2];
}
}
}
} // namespace perception
} // namespace vision
} // namespace fastdeploy
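
A side note on the `Normalize` helper above: `(v * scale - mean) / std` is an affine map per channel, so the per-pixel loop can be folded into OpenCV's `convertTo`. A sketch of an equivalent fused version (an observation, not part of this commit):
```c++
#include <opencv2/opencv.hpp>
#include <vector>

// Equivalent to Normalize() above: out = (in * scale - mean) / std computed
// as one per-channel affine transform out = alpha * in + beta.
void NormalizeFused(cv::Mat* im, const std::vector<float>& mean,
                    const std::vector<float>& std_dev, float scale) {
  std::vector<cv::Mat> channels(3);
  cv::split(*im, channels);
  for (int c = 0; c < 3; ++c) {
    const double alpha = scale / std_dev[c];
    const double beta = -mean[c] / std_dev[c];
    channels[c].convertTo(channels[c], CV_32FC1, alpha, beta);
  }
  cv::merge(channels, *im);
}
```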

View File

@@ -0,0 +1,66 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/manager.h"
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/vision/common/result.h"
namespace fastdeploy {
namespace vision {
namespace perception {
/*! @brief Preprocessor object for the Petr series of models.
 */
class FASTDEPLOY_DECL PetrPreprocessor : public ProcessorManager {
public:
PetrPreprocessor() = default;
  /** \brief Create a preprocessor instance for the Petr model
   *
   * \param[in] config_file Path of configuration file for deployment, e.g. petr/infer_cfg.yml
   */
explicit PetrPreprocessor(const std::string& config_file);
  /** \brief Process the input image batch and prepare the input tensors for runtime
   *
   * \param[in] image_batch The input image batch; each element comes from cv::imread()
   * \param[in] outputs The output tensors which will be fed into the runtime
   * \return true if the preprocess succeeds, otherwise false
   */
bool Apply(FDMatBatch* image_batch, std::vector<FDTensor>* outputs);
  void Normalize(cv::Mat* im, const std::vector<float>& mean,
                 const std::vector<float>& std, float scale);
protected:
bool BuildPreprocessPipelineFromConfig();
std::vector<std::shared_ptr<Processor>> processors_;
bool disable_permute_ = false;
bool initialized_ = false;
std::string config_file_;
float scale_ = 1.0f;
std::vector<float> mean_;
std::vector<float> std_;
std::vector<float> input_k_data_;
};
} // namespace perception
} // namespace vision
} // namespace fastdeploy

View File

@@ -17,6 +17,7 @@
namespace fastdeploy {
void BindSmoke(pybind11::module& m);
void BindPetr(pybind11::module& m);
void BindPerception(pybind11::module& m) {
auto perception_module =

View File

@@ -14,3 +14,4 @@
from __future__ import absolute_import
from .paddle3d.smoke import *
from .paddle3d.petr import *

View File

@@ -0,0 +1,106 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import logging
from .... import FastDeployModel, ModelFormat
from .... import c_lib_wrap as C
class PetrPreprocessor:
def __init__(self, config_file):
"""Create a preprocessor for Petr
"""
self._preprocessor = C.vision.perception.PetrPreprocessor(config_file)
def run(self, input_ims):
"""Preprocess input images for Petr
        :param: input_ims: (list of numpy.ndarray)The input images
:return: list of FDTensor
"""
return self._preprocessor.run(input_ims)
class PetrPostprocessor:
def __init__(self):
"""Create a postprocessor for Petr
"""
self._postprocessor = C.vision.perception.PetrPostprocessor()
def run(self, runtime_results):
"""Postprocess the runtime results for Petr
:param: runtime_results: (list of FDTensor)The output FDTensor results from runtime
:return: list of PerceptionResult(If the runtime_results is predict by batched samples, the length of this list equals to the batch size)
"""
return self._postprocessor.run(runtime_results)
class Petr(FastDeployModel):
def __init__(self,
model_file,
params_file,
config_file,
runtime_option=None,
model_format=ModelFormat.PADDLE):
"""Load a SMoke model exported by Petr.
:param model_file: (str)Path of model file, e.g ./petr.pdmodel
:param params_file: (str)Path of parameters file, e.g ./petr.pdiparams
:param config_file: (str)Path of config file, e.g ./infer_cfg.yaml
:param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU
:param model_format: (fastdeploy.ModelForamt)Model format of the loaded model
"""
super(Petr, self).__init__(runtime_option)
self._model = C.vision.perception.Petr(
model_file, params_file, config_file, self._runtime_option,
model_format)
assert self.initialized, "Petr initialize failed."
def predict(self, input_image):
"""Detect an input image
:param input_image: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
:return: PerceptionResult
"""
return self._model.predict(input_image)
def batch_predict(self, images):
"""Classify a batch of input image
:param im: (list of numpy.ndarray) The input image list, each element is a 3-D array with layout HWC, BGR format
:return list of PerceptionResult
"""
return self._model.batch_predict(images)
@property
def preprocessor(self):
"""Get PetrPreprocessor object of the loaded model
:return PetrPreprocessor
"""
return self._model.preprocessor
@property
def postprocessor(self):
"""Get PetrPostprocessor object of the loaded model
:return PetrPostprocessor
"""
return self._model.postprocessor