mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-21 07:40:37 +08:00
[Model] Add Paddle3D smoke model (#1766)
* add smoke model * add 3d vis * update code * update doc * mv paddle3d from detection to perception * update result for velocity * update code for CI * add set input data for TRT backend * add serving support for smoke model * update code * update code * update code --------- Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com>
This commit is contained in:
58
fastdeploy/vision/perception/paddle3d/smoke/postprocessor.cc
Normal file
58
fastdeploy/vision/perception/paddle3d/smoke/postprocessor.cc
Normal file
@@ -0,0 +1,58 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/perception/paddle3d/smoke/postprocessor.h"
|
||||
|
||||
#include "fastdeploy/vision/utils/utils.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace perception {
|
||||
|
||||
SmokePostprocessor::SmokePostprocessor() {}
|
||||
|
||||
bool SmokePostprocessor::Run(const std::vector<FDTensor>& tensors,
|
||||
std::vector<PerceptionResult>* results) {
|
||||
results->resize(1);
|
||||
(*results)[0].Clear();
|
||||
(*results)[0].Reserve(tensors[0].shape[0]);
|
||||
if (tensors[0].dtype != FDDataType::FP32) {
|
||||
FDERROR << "Only support post process with float32 data." << std::endl;
|
||||
return false;
|
||||
}
|
||||
const float* data = reinterpret_cast<const float*>(tensors[0].Data());
|
||||
auto result = &(*results)[0];
|
||||
for (int i = 0; i < tensors[0].shape[0] * tensors[0].shape[1]; i += 14) {
|
||||
// item 1 : class
|
||||
// item 2 : observation angle α
|
||||
// item 3 ~ 6 : box2d x1, y1, x2, y2
|
||||
// item 7 ~ 9 : box3d h, w, l
|
||||
// item 10 ~ 12 : box3d bottom center x, y, z
|
||||
// item 13 : box3d yaw angle
|
||||
// item 14 : score
|
||||
std::vector<float> vec(data + i, data + i + 14);
|
||||
result->scores.push_back(vec[13]);
|
||||
result->label_ids.push_back(vec[0]);
|
||||
result->boxes.emplace_back(std::array<float, 7>{
|
||||
vec[2], vec[3], vec[4], vec[5], vec[6], vec[7], vec[8]});
|
||||
result->center.emplace_back(std::array<float, 3>{vec[9], vec[10], vec[11]});
|
||||
result->observation_angle.push_back(vec[1]);
|
||||
result->yaw_angle.push_back(vec[12]);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace perception
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
48
fastdeploy/vision/perception/paddle3d/smoke/postprocessor.h
Executable file
48
fastdeploy/vision/perception/paddle3d/smoke/postprocessor.h
Executable file
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/vision/common/processors/transform.h"
|
||||
#include "fastdeploy/vision/common/result.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
|
||||
namespace perception {
|
||||
/*! @brief Postprocessor object for Smoke serials model.
|
||||
*/
|
||||
class FASTDEPLOY_DECL SmokePostprocessor {
|
||||
public:
|
||||
/** \brief Create a postprocessor instance for Smoke serials model
|
||||
*/
|
||||
SmokePostprocessor();
|
||||
|
||||
/** \brief Process the result of runtime and fill to PerceptionResult structure
|
||||
*
|
||||
* \param[in] tensors The inference result from runtime
|
||||
* \param[in] result The output result of detection
|
||||
* \param[in] ims_info The shape info list, record input_shape and output_shape
|
||||
* \return true if the postprocess successed, otherwise false
|
||||
*/
|
||||
bool Run(const std::vector<FDTensor>& tensors,
|
||||
std::vector<PerceptionResult>* results);
|
||||
|
||||
|
||||
protected:
|
||||
float conf_threshold_;
|
||||
};
|
||||
|
||||
} // namespace perception
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
161
fastdeploy/vision/perception/paddle3d/smoke/preprocessor.cc
Executable file
161
fastdeploy/vision/perception/paddle3d/smoke/preprocessor.cc
Executable file
@@ -0,0 +1,161 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/perception/paddle3d/smoke/preprocessor.h"
|
||||
|
||||
#include "fastdeploy/function/concat.h"
|
||||
#include "yaml-cpp/yaml.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace perception {
|
||||
|
||||
SmokePreprocessor::SmokePreprocessor(const std::string& config_file) {
|
||||
config_file_ = config_file;
|
||||
FDASSERT(BuildPreprocessPipelineFromConfig(),
|
||||
"Failed to create Paddle3DDetPreprocessor.");
|
||||
initialized_ = true;
|
||||
}
|
||||
|
||||
bool SmokePreprocessor::BuildPreprocessPipelineFromConfig() {
|
||||
processors_.clear();
|
||||
YAML::Node cfg;
|
||||
try {
|
||||
cfg = YAML::LoadFile(config_file_);
|
||||
} catch (YAML::BadFile& e) {
|
||||
FDERROR << "Failed to load yaml file " << config_file_
|
||||
<< ", maybe you should check this file." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
// read for preprocess
|
||||
processors_.push_back(std::make_shared<BGR2RGB>());
|
||||
|
||||
bool has_permute = false;
|
||||
for (const auto& op : cfg["Preprocess"]) {
|
||||
std::string op_name = op["type"].as<std::string>();
|
||||
if (op_name == "NormalizeImage") {
|
||||
auto mean = op["mean"].as<std::vector<float>>();
|
||||
auto std = op["std"].as<std::vector<float>>();
|
||||
bool is_scale = true;
|
||||
if (op["is_scale"]) {
|
||||
is_scale = op["is_scale"].as<bool>();
|
||||
}
|
||||
std::string norm_type = "mean_std";
|
||||
if (op["norm_type"]) {
|
||||
norm_type = op["norm_type"].as<std::string>();
|
||||
}
|
||||
if (norm_type != "mean_std") {
|
||||
std::fill(mean.begin(), mean.end(), 0.0);
|
||||
std::fill(std.begin(), std.end(), 1.0);
|
||||
}
|
||||
processors_.push_back(std::make_shared<Normalize>(mean, std, is_scale));
|
||||
} else if (op_name == "Resize") {
|
||||
bool keep_ratio = op["keep_ratio"].as<bool>();
|
||||
auto target_size = op["target_size"].as<std::vector<int>>();
|
||||
int interp = op["interp"].as<int>();
|
||||
FDASSERT(target_size.size() == 2,
|
||||
"Require size of target_size be 2, but now it's %lu.",
|
||||
target_size.size());
|
||||
if (!keep_ratio) {
|
||||
int width = target_size[1];
|
||||
int height = target_size[0];
|
||||
processors_.push_back(
|
||||
std::make_shared<Resize>(width, height, -1.0, -1.0, interp, false));
|
||||
} else {
|
||||
int min_target_size = std::min(target_size[0], target_size[1]);
|
||||
int max_target_size = std::max(target_size[0], target_size[1]);
|
||||
std::vector<int> max_size;
|
||||
if (max_target_size > 0) {
|
||||
max_size.push_back(max_target_size);
|
||||
max_size.push_back(max_target_size);
|
||||
}
|
||||
processors_.push_back(std::make_shared<ResizeByShort>(
|
||||
min_target_size, interp, true, max_size));
|
||||
}
|
||||
} else if (op_name == "Permute") {
|
||||
// Do nothing, do permute as the last operation
|
||||
has_permute = true;
|
||||
continue;
|
||||
} else {
|
||||
FDERROR << "Unexcepted preprocess operator: " << op_name << "."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (!disable_permute_) {
|
||||
if (has_permute) {
|
||||
// permute = cast<float> + HWC2CHW
|
||||
processors_.push_back(std::make_shared<Cast>("float"));
|
||||
processors_.push_back(std::make_shared<HWC2CHW>());
|
||||
}
|
||||
}
|
||||
|
||||
// Fusion will improve performance
|
||||
FuseTransforms(&processors_);
|
||||
|
||||
input_k_data_ = cfg["k_data"].as<std::vector<float>>();
|
||||
input_ratio_data_ = cfg["ratio_data"].as<std::vector<float>>();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SmokePreprocessor::Apply(FDMatBatch* image_batch,
|
||||
std::vector<FDTensor>* outputs) {
|
||||
if (image_batch->mats->empty()) {
|
||||
FDERROR << "The size of input images should be greater than 0."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
if (!initialized_) {
|
||||
FDERROR << "The preprocessor is not initialized." << std::endl;
|
||||
return false;
|
||||
}
|
||||
// There are 3 outputs, image, k_data, ratio_data
|
||||
outputs->resize(3);
|
||||
int batch = static_cast<int>(image_batch->mats->size());
|
||||
|
||||
// Allocate memory for k_data
|
||||
(*outputs)[2].Resize({batch, 3, 3}, FDDataType::FP32);
|
||||
|
||||
// Allocate memory for ratio_data
|
||||
(*outputs)[0].Resize({batch, 2}, FDDataType::FP32);
|
||||
|
||||
auto* k_data_ptr = reinterpret_cast<float*>((*outputs)[2].MutableData());
|
||||
|
||||
auto* ratio_data_ptr = reinterpret_cast<float*>((*outputs)[0].MutableData());
|
||||
|
||||
for (size_t i = 0; i < image_batch->mats->size(); ++i) {
|
||||
FDMat* mat = &(image_batch->mats->at(i));
|
||||
for (size_t j = 0; j < processors_.size(); ++j) {
|
||||
if (!(*(processors_[j].get()))(mat)) {
|
||||
FDERROR << "Failed to processs image:" << i << " in "
|
||||
<< processors_[j]->Name() << "." << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
memcpy(k_data_ptr + i * 9, input_k_data_.data(), 9 * sizeof(float));
|
||||
memcpy(ratio_data_ptr + i * 2, input_ratio_data_.data(), 2 * sizeof(float));
|
||||
}
|
||||
|
||||
FDTensor* tensor = image_batch->Tensor();
|
||||
(*outputs)[1].SetExternalData(tensor->Shape(), tensor->Dtype(),
|
||||
tensor->Data(), tensor->device,
|
||||
tensor->device_id);
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace perception
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
61
fastdeploy/vision/perception/paddle3d/smoke/preprocessor.h
Executable file
61
fastdeploy/vision/perception/paddle3d/smoke/preprocessor.h
Executable file
@@ -0,0 +1,61 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include "fastdeploy/vision/common/processors/manager.h"
|
||||
#include "fastdeploy/vision/common/processors/transform.h"
|
||||
#include "fastdeploy/vision/common/result.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
|
||||
namespace perception {
|
||||
/*! @brief Preprocessor object for Smoke serials model.
|
||||
*/
|
||||
class FASTDEPLOY_DECL SmokePreprocessor : public ProcessorManager {
|
||||
public:
|
||||
SmokePreprocessor() = default;
|
||||
/** \brief Create a preprocessor instance for Smoke model
|
||||
*
|
||||
* \param[in] config_file Path of configuration file for deployment, e.g smoke/infer_cfg.yml
|
||||
*/
|
||||
explicit SmokePreprocessor(const std::string& config_file);
|
||||
|
||||
/** \brief Process the input image and prepare input tensors for runtime
|
||||
*
|
||||
* \param[in] images The input image data list, all the elements are returned by cv::imread()
|
||||
* \param[in] outputs The output tensors which will feed in runtime
|
||||
* \param[in] ims_info The shape info list, record input_shape and output_shape
|
||||
* \return true if the preprocess successed, otherwise false
|
||||
*/
|
||||
bool Apply(FDMatBatch* image_batch, std::vector<FDTensor>* outputs);
|
||||
|
||||
protected:
|
||||
bool BuildPreprocessPipelineFromConfig();
|
||||
std::vector<std::shared_ptr<Processor>> processors_;
|
||||
|
||||
bool disable_permute_ = false;
|
||||
|
||||
bool initialized_ = false;
|
||||
|
||||
std::string config_file_;
|
||||
|
||||
std::vector<float> input_k_data_;
|
||||
|
||||
std::vector<float> input_ratio_data_;
|
||||
};
|
||||
|
||||
} // namespace perception
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
82
fastdeploy/vision/perception/paddle3d/smoke/smoke.cc
Executable file
82
fastdeploy/vision/perception/paddle3d/smoke/smoke.cc
Executable file
@@ -0,0 +1,82 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/perception/paddle3d/smoke/smoke.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace perception {
|
||||
|
||||
Smoke::Smoke(const std::string& model_file, const std::string& params_file,
|
||||
const std::string& config_file, const RuntimeOption& custom_option,
|
||||
const ModelFormat& model_format)
|
||||
: preprocessor_(config_file) {
|
||||
valid_cpu_backends = {Backend::PDINFER};
|
||||
valid_gpu_backends = {Backend::PDINFER};
|
||||
|
||||
runtime_option = custom_option;
|
||||
runtime_option.model_format = model_format;
|
||||
runtime_option.model_file = model_file;
|
||||
runtime_option.params_file = params_file;
|
||||
initialized = Initialize();
|
||||
}
|
||||
|
||||
bool Smoke::Initialize() {
|
||||
if (!InitRuntime()) {
|
||||
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Smoke::Predict(const cv::Mat& im, PerceptionResult* result) {
|
||||
std::vector<PerceptionResult> results;
|
||||
if (!BatchPredict({im}, &results)) {
|
||||
return false;
|
||||
}
|
||||
if (results.size()) {
|
||||
*result = std::move(results[0]);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Smoke::BatchPredict(const std::vector<cv::Mat>& images,
|
||||
std::vector<PerceptionResult>* results) {
|
||||
std::vector<FDMat> fd_images = WrapMat(images);
|
||||
|
||||
if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
|
||||
FDERROR << "Failed to preprocess the input image." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
reused_input_tensors_[0].name = InputInfoOfRuntime(0).name;
|
||||
reused_input_tensors_[1].name = InputInfoOfRuntime(1).name;
|
||||
reused_input_tensors_[2].name = InputInfoOfRuntime(2).name;
|
||||
|
||||
if (!Infer(reused_input_tensors_, &reused_output_tensors_)) {
|
||||
FDERROR << "Failed to inference by runtime." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!postprocessor_.Run(reused_output_tensors_, results)) {
|
||||
FDERROR << "Failed to postprocess the inference results by runtime."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace perception
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
78
fastdeploy/vision/perception/paddle3d/smoke/smoke.h
Executable file
78
fastdeploy/vision/perception/paddle3d/smoke/smoke.h
Executable file
@@ -0,0 +1,78 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. //NOLINT
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fastdeploy/fastdeploy_model.h"
|
||||
#include "fastdeploy/vision/perception/paddle3d/smoke/preprocessor.h"
|
||||
#include "fastdeploy/vision/perception/paddle3d/smoke/postprocessor.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace perception {
|
||||
/*! @brief smoke model object used when to load a smoke model exported by smoke.
|
||||
*/
|
||||
class FASTDEPLOY_DECL Smoke : public FastDeployModel {
|
||||
public:
|
||||
/** \brief Set path of model file and the configuration of runtime.
|
||||
*
|
||||
* \param[in] model_file Path of model file, e.g smoke/model.pdiparams
|
||||
* \param[in] params_file Path of parameter file, e.g smoke/model.pdiparams, if the model format is ONNX, this parameter will be ignored
|
||||
* \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in "valid_cpu_backends"
|
||||
* \param[in] model_format Model format of the loaded model, default is Paddle format
|
||||
*/
|
||||
Smoke(const std::string& model_file, const std::string& params_file,
|
||||
const std::string& config_file,
|
||||
const RuntimeOption& custom_option = RuntimeOption(),
|
||||
const ModelFormat& model_format = ModelFormat::PADDLE);
|
||||
|
||||
std::string ModelName() const { return "Paddle3D/smoke"; }
|
||||
|
||||
/** \brief Predict the perception result for an input image
|
||||
*
|
||||
* \param[in] img The input image data, comes from cv::imread(), is a 3-D array with layout HWC, BGR format
|
||||
* \param[in] result The output perception result will be writen to this structure
|
||||
* \return true if the prediction successed, otherwise false
|
||||
*/
|
||||
virtual bool Predict(const cv::Mat& img, PerceptionResult* result);
|
||||
|
||||
/** \brief Predict the perception results for a batch of input images
|
||||
*
|
||||
* \param[in] imgs, The input image list, each element comes from cv::imread()
|
||||
* \param[in] results The output perception result list
|
||||
* \return true if the prediction successed, otherwise false
|
||||
*/
|
||||
virtual bool BatchPredict(const std::vector<cv::Mat>& imgs,
|
||||
std::vector<PerceptionResult>* results);
|
||||
|
||||
/// Get preprocessor reference of Smoke
|
||||
virtual SmokePreprocessor& GetPreprocessor() {
|
||||
return preprocessor_;
|
||||
}
|
||||
|
||||
/// Get postprocessor reference of Smoke
|
||||
virtual SmokePostprocessor& GetPostprocessor() {
|
||||
return postprocessor_;
|
||||
}
|
||||
|
||||
protected:
|
||||
bool Initialize();
|
||||
SmokePreprocessor preprocessor_;
|
||||
SmokePostprocessor postprocessor_;
|
||||
bool initialized_ = false;
|
||||
};
|
||||
|
||||
} // namespace perception
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
92
fastdeploy/vision/perception/paddle3d/smoke/smoke_pybind.cc
Normal file
92
fastdeploy/vision/perception/paddle3d/smoke/smoke_pybind.cc
Normal file
@@ -0,0 +1,92 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/pybind/main.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
void BindSmoke(pybind11::module& m) {
|
||||
pybind11::class_<vision::perception::SmokePreprocessor,
|
||||
vision::ProcessorManager>(m, "SmokePreprocessor")
|
||||
.def(pybind11::init<std::string>())
|
||||
.def("run", [](vision::perception::SmokePreprocessor& self,
|
||||
std::vector<pybind11::array>& im_list) {
|
||||
std::vector<vision::FDMat> images;
|
||||
for (size_t i = 0; i < im_list.size(); ++i) {
|
||||
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
|
||||
}
|
||||
std::vector<FDTensor> outputs;
|
||||
if (!self.Run(&images, &outputs)) {
|
||||
throw std::runtime_error(
|
||||
"Failed to preprocess the input data in SmokePreprocessor.");
|
||||
}
|
||||
for (size_t i = 0; i < outputs.size(); ++i) {
|
||||
outputs[i].StopSharing();
|
||||
}
|
||||
return outputs;
|
||||
});
|
||||
|
||||
pybind11::class_<vision::perception::SmokePostprocessor>(m,
|
||||
"SmokePostprocessor")
|
||||
.def(pybind11::init<>())
|
||||
.def("run",
|
||||
[](vision::perception::SmokePostprocessor& self,
|
||||
std::vector<FDTensor>& inputs) {
|
||||
std::vector<vision::PerceptionResult> results;
|
||||
if (!self.Run(inputs, &results)) {
|
||||
throw std::runtime_error(
|
||||
"Failed to postprocess the runtime result in "
|
||||
"SmokePostprocessor.");
|
||||
}
|
||||
return results;
|
||||
})
|
||||
.def("run", [](vision::perception::SmokePostprocessor& self,
|
||||
std::vector<pybind11::array>& input_array) {
|
||||
std::vector<vision::PerceptionResult> results;
|
||||
std::vector<FDTensor> inputs;
|
||||
PyArrayToTensorList(input_array, &inputs, /*share_buffer=*/true);
|
||||
if (!self.Run(inputs, &results)) {
|
||||
throw std::runtime_error(
|
||||
"Failed to postprocess the runtime result in "
|
||||
"SmokePostprocessor.");
|
||||
}
|
||||
return results;
|
||||
});
|
||||
|
||||
pybind11::class_<vision::perception::Smoke, FastDeployModel>(m, "Smoke")
|
||||
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
|
||||
ModelFormat>())
|
||||
.def("predict",
|
||||
[](vision::perception::Smoke& self, pybind11::array& data) {
|
||||
auto mat = PyArrayToCvMat(data);
|
||||
vision::PerceptionResult res;
|
||||
self.Predict(mat, &res);
|
||||
return res;
|
||||
})
|
||||
.def("batch_predict",
|
||||
[](vision::perception::Smoke& self,
|
||||
std::vector<pybind11::array>& data) {
|
||||
std::vector<cv::Mat> images;
|
||||
for (size_t i = 0; i < data.size(); ++i) {
|
||||
images.push_back(PyArrayToCvMat(data[i]));
|
||||
}
|
||||
std::vector<vision::PerceptionResult> results;
|
||||
self.BatchPredict(images, &results);
|
||||
return results;
|
||||
})
|
||||
.def_property_readonly("preprocessor",
|
||||
&vision::perception::Smoke::GetPreprocessor)
|
||||
.def_property_readonly("postprocessor",
|
||||
&vision::perception::Smoke::GetPostprocessor);
|
||||
}
|
||||
} // namespace fastdeploy
|
Reference in New Issue
Block a user