[Model] [Part 1] Refactor PaddleClas module (#508)

* Split PaddleClas Module refactor

* Split PaddleClas Module refactor

* fix bug
This commit is contained in:
Jason
2022-11-07 15:09:00 +08:00
committed by GitHub
parent 40b099ac99
commit 6633fa3db9
17 changed files with 227 additions and 269 deletions

View File

@@ -131,6 +131,7 @@ void FDTensor::Resize(const std::vector<int64_t>& new_shape,
const FDDataType& data_type,
const std::string& tensor_name,
const Device& new_device) {
external_data_ptr = nullptr;
name = tensor_name;
device = new_device;
dtype = data_type;

View File

@@ -93,6 +93,12 @@ struct FASTDEPLOY_DECL FDTensor {
// Total number of elements in this tensor
int Numel() const;
// Get shape of FDTensor
std::vector<int64_t> Shape() const { return shape; }
// Get dtype of FDTensor
FDDataType Dtype() const { return dtype; }
void Resize(size_t nbytes);
void Resize(const std::vector<int64_t>& new_shape);

View File

@@ -0,0 +1,35 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindFDTensor(pybind11::module& m) {
pybind11::class_<FDTensor>(m, "FDTensor")
.def(pybind11::init<>(), "Default Constructor")
.def_readwrite("name", &FDTensor::name)
.def_readonly("shape", &FDTensor::shape)
.def_readonly("dtype", &FDTensor::dtype)
.def_readonly("device", &FDTensor::device)
.def("numpy", [](FDTensor& self) {
return TensorToPyArray(self);
})
.def("from_numpy", [](FDTensor& self, pybind11::array& pyarray, bool share_buffer = false) {
PyArrayToTensor(pyarray, &self, share_buffer);
});
}
} // namespace fastdeploy

View File

@@ -16,6 +16,7 @@
namespace fastdeploy {
void BindFDTensor(pybind11::module&);
void BindRuntime(pybind11::module&);
void BindFDModel(pybind11::module&);
void BindVision(pybind11::module&);
@@ -70,7 +71,7 @@ void PyArrayToTensor(pybind11::array& pyarray, FDTensor* tensor,
data_shape.insert(data_shape.begin(), pyarray.shape(),
pyarray.shape() + pyarray.ndim());
if (share_buffer) {
tensor-> SetExternalData(data_shape, dtype,
tensor->SetExternalData(data_shape, dtype,
pyarray.mutable_data());
} else {
tensor->Resize(data_shape, dtype);
@@ -80,6 +81,7 @@ void PyArrayToTensor(pybind11::array& pyarray, FDTensor* tensor,
void PyArrayToTensorList(std::vector<pybind11::array>& pyarrays, std::vector<FDTensor>* tensors,
bool share_buffer) {
tensors->resize(pyarrays.size());
for(auto i = 0; i < pyarrays.size(); ++i) {
PyArrayToTensor(pyarrays[i], &(*tensors)[i], share_buffer);
}
@@ -88,7 +90,7 @@ void PyArrayToTensorList(std::vector<pybind11::array>& pyarrays, std::vector<FDT
pybind11::array TensorToPyArray(const FDTensor& tensor) {
auto numpy_dtype = FDDataTypeToNumpyDataType(tensor.dtype);
auto out = pybind11::array(numpy_dtype, tensor.shape);
memcpy(out.mutable_data(), tensor.Data(), tensor.Numel() * FDDataTypeSize(tensor.dtype));
memcpy(out.mutable_data(), tensor.CpuData(), tensor.Nbytes());
return out;
}
@@ -149,6 +151,7 @@ PYBIND11_MODULE(@PY_LIBRARY_NAME@, m) {
"Make programer easier to deploy deeplearning model, save time to save "
"the world!";
BindFDTensor(m);
BindRuntime(m);
BindFDModel(m);
#ifdef ENABLE_VISION

View File

@@ -17,6 +17,7 @@
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/eval.h>
#include <type_traits>

View File

@@ -162,6 +162,25 @@ void BindRuntime(pybind11::module& m) {
}
return results;
})
.def("infer", [](Runtime& self, std::map<std::string, FDTensor>& data) {
std::vector<FDTensor> inputs;
inputs.reserve(data.size());
for (auto iter = data.begin(); iter != data.end(); ++iter) {
FDTensor tensor;
tensor.SetExternalData(iter->second.Shape(), iter->second.Dtype(), iter->second.Data(), iter->second.device);
tensor.name = iter->first;
inputs.push_back(tensor);
}
std::vector<FDTensor> outputs;
if (!self.Infer(inputs, &outputs)) {
pybind11::eval("raise Exception('Failed to inference with Runtime.')");
}
return outputs;
})
.def("infer", [](Runtime& self, std::vector<FDTensor>& inputs) {
std::vector<FDTensor> outputs;
return self.Infer(inputs, &outputs);
})
.def("num_inputs", &Runtime::NumInputs)
.def("num_outputs", &Runtime::NumOutputs)
.def("get_input_info", &Runtime::GetInputInfo)
@@ -202,33 +221,6 @@ void BindRuntime(pybind11::module& m) {
.value("FP64", FDDataType::FP64)
.value("UINT8", FDDataType::UINT8);
pybind11::class_<FDTensor>(m, "FDTensor", pybind11::buffer_protocol())
.def(pybind11::init())
.def("cpu_data",
[](FDTensor& self) {
auto ptr = self.CpuData();
auto numel = self.Numel();
auto dtype = FDDataTypeToNumpyDataType(self.dtype);
auto base = pybind11::array(dtype, self.shape);
return pybind11::array(dtype, self.shape, ptr, base);
})
.def("resize", static_cast<void (FDTensor::*)(size_t)>(&FDTensor::Resize))
.def("resize",
static_cast<void (FDTensor::*)(const std::vector<int64_t>&)>(
&FDTensor::Resize))
.def(
"resize",
[](FDTensor& self, const std::vector<int64_t>& shape,
const FDDataType& dtype, const std::string& name,
const Device& device) { self.Resize(shape, dtype, name, device); })
.def("numel", &FDTensor::Numel)
.def("nbytes", &FDTensor::Nbytes)
.def_readwrite("name", &FDTensor::name)
.def_readwrite("is_pinned_memory", &FDTensor::is_pinned_memory)
.def_readonly("shape", &FDTensor::shape)
.def_readonly("dtype", &FDTensor::dtype)
.def_readonly("device", &FDTensor::device);
m.def("get_available_backends", []() { return GetAvailableBackends(); });
}

View File

@@ -1,88 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/common/processors/limit_long.h"
namespace fastdeploy {
namespace vision {
bool LimitLong::ImplByOpenCV(Mat* mat) {
cv::Mat* im = mat->GetOpenCVMat();
int origin_w = im->cols;
int origin_h = im->rows;
int im_size_max = std::max(origin_w, origin_h);
int target = im_size_max;
if (max_long_ > 0 && im_size_max > max_long_) {
target = max_long_;
} else if (min_long_ > 0 && im_size_max < min_long_) {
target = min_long_;
}
if (target != im_size_max) {
double scale =
static_cast<double>(target) / static_cast<double>(im_size_max);
cv::resize(*im, *im, cv::Size(), scale, scale, interp_);
mat->SetWidth(im->cols);
mat->SetHeight(im->rows);
}
return true;
}
#ifdef ENABLE_FLYCV
bool LimitLong::ImplByFlyCV(Mat* mat) {
fcv::Mat* im = mat->GetFlyCVMat();
int origin_w = im->width();
int origin_h = im->height();
int im_size_max = std::max(origin_w, origin_h);
int target = im_size_max;
if (max_long_ > 0 && im_size_max > max_long_) {
target = max_long_;
} else if (min_long_ > 0 && im_size_max < min_long_) {
target = min_long_;
}
if (target != im_size_max) {
double scale =
static_cast<double>(target) / static_cast<double>(im_size_max);
if (fabs(scale - 1.0) < 1e-06) {
return true;
}
auto interp_method = fcv::InterpolationType::INTER_LINEAR;
if (interp_ == 0) {
interp_method = fcv::InterpolationType::INTER_NEAREST;
} else if (interp_ == 1) {
interp_method = fcv::InterpolationType::INTER_LINEAR;
} else if (interp_ == 2) {
interp_method = fcv::InterpolationType::INTER_CUBIC;
} else {
FDERROR << "LimitLong: Only support interp_ be 0/1/2 with FlyCV, but "
"now it's "
<< interp_ << "." << std::endl;
return false;
}
fcv::Mat new_im;
fcv::resize(*im, new_im, fcv::Size(), scale, scale, interp_method);
mat->SetMat(new_im);
mat->SetWidth(new_im.width());
mat->SetHeight(new_im.height());
}
return true;
}
#endif
bool LimitLong::Run(Mat* mat, int max_long, int min_long, int interp,
ProcLib lib) {
auto l = LimitLong(max_long, min_long, interp);
return l(mat, lib);
}
} // namespace vision
} // namespace fastdeploy

View File

@@ -1,51 +0,0 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/base.h"
namespace fastdeploy {
namespace vision {
class FASTDEPLOY_DECL LimitLong : public Processor {
public:
explicit LimitLong(int max_long = -1, int min_long = -1, int interp = 1) {
max_long_ = max_long;
min_long_ = min_long;
interp_ = interp;
}
// Limit the long edge of image.
// If the long edge is larger than max_long_, resize the long edge
// to max_long_, while scale the short edge proportionally.
// If the long edge is smaller than min_long_, resize the long edge
// to min_long_, while scale the short edge proportionally.
bool ImplByOpenCV(Mat* mat);
#ifdef ENABLE_FLYCV
bool ImplByFlyCV(Mat* mat);
#endif
std::string Name() { return "LimitLong"; }
static bool Run(Mat* mat, int max_long = -1, int min_long = -1,
int interp = 1, ProcLib lib = ProcLib::DEFAULT);
int GetMaxLong() const { return max_long_; }
private:
int max_long_;
int min_long_;
int interp_;
};
} // namespace vision
} // namespace fastdeploy

View File

@@ -65,7 +65,7 @@ bool LimitShort::ImplByFlyCV(Mat* mat) {
} else if (interp_ == 2) {
interp_method = fcv::InterpolationType::INTER_CUBIC;
} else {
FDERROR << "LimitLong: Only support interp_ be 0/1/2 with FlyCV, but "
FDERROR << "LimitShort: Only support interp_ be 0/1/2 with FlyCV, but "
"now it's "
<< interp_ << "." << std::endl;
return false;

View File

@@ -174,5 +174,18 @@ Mat Mat::Create(int height, int width, int channels,
return mat;
}
FDMat WrapMat(const cv::Mat& image) {
FDMat mat(image);
return mat;
}
std::vector<FDMat> WrapMat(const std::vector<cv::Mat>& images) {
std::vector<FDMat> mats;
for (size_t i = 0; i < images.size(); ++i) {
mats.emplace_back(FDMat(images[i]));
}
return mats;
}
} // namespace vision
} // namespace fastdeploy

View File

@@ -147,5 +147,15 @@ struct FASTDEPLOY_DECL Mat {
FDDataType type, void* data, ProcLib lib);
};
typedef Mat FDMat;
/*
* @brief Wrap a cv::Mat to FDMat, there's no memory copy, memory buffer is managed by user
*/
FASTDEPLOY_DECL FDMat WrapMat(const cv::Mat& image);
/*
* Warp a vector<cv::Mat> to vector<FDMat>, there's no memory copy, memory buffer is managed by user
*/
FASTDEPLOY_DECL std::vector<FDMat> WrapMat(const std::vector<cv::Mat>& images);
} // namespace vision
} // namespace fastdeploy

View File

@@ -79,7 +79,7 @@ bool Resize::ImplByFlyCV(Mat* mat) {
} else if (interp_ == 2) {
interp_method = fcv::InterpolationType::INTER_CUBIC;
} else {
FDERROR << "LimitLong: Only support interp_ be 0/1/2 with FlyCV, but "
FDERROR << "Resize: Only support interp_ be 0/1/2 with FlyCV, but "
"now it's "
<< interp_ << "." << std::endl;
return false;

View File

@@ -0,0 +1,106 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/common/processors/transform.h"
namespace fastdeploy {
namespace vision {
void FuseNormalizeCast(
std::vector<std::shared_ptr<Processor>>* processors) {
// Fuse Normalize and Cast<Float>
int cast_index = -1;
for (size_t i = 0; i < processors->size(); ++i) {
if ((*processors)[i]->Name() == "Cast") {
if (i == 0) {
continue;
}
if ((*processors)[i - 1]->Name() != "Normalize" &&
(*processors)[i - 1]->Name() != "NormalizeAndPermute") {
continue;
}
cast_index = i;
}
}
if (cast_index < 0) {
return;
}
if (dynamic_cast<Cast*>((*processors)[cast_index].get())->GetDtype() !=
"float") {
return;
}
processors->erase(processors->begin() + cast_index);
FDINFO << (*processors)[cast_index - 1]->Name() << " and Cast are fused to "
<< (*processors)[cast_index - 1]->Name()
<< " in preprocessing pipeline." << std::endl;
}
void FuseNormalizeHWC2CHW(
std::vector<std::shared_ptr<Processor>>* processors) {
// Fuse Normalize and HWC2CHW to NormalizeAndPermute
int hwc2chw_index = -1;
for (size_t i = 0; i < processors->size(); ++i) {
if ((*processors)[i]->Name() == "HWC2CHW") {
if (i == 0) {
continue;
}
if ((*processors)[i - 1]->Name() != "Normalize") {
continue;
}
hwc2chw_index = i;
}
}
if (hwc2chw_index < 0) {
return;
}
// Get alpha and beta of Normalize
std::vector<float> alpha =
dynamic_cast<Normalize*>((*processors)[hwc2chw_index - 1].get())
->GetAlpha();
std::vector<float> beta =
dynamic_cast<Normalize*>((*processors)[hwc2chw_index - 1].get())
->GetBeta();
// Delete Normalize and HWC2CHW
processors->erase(processors->begin() + hwc2chw_index);
processors->erase(processors->begin() + hwc2chw_index - 1);
// Add NormalizeAndPermute
std::vector<float> mean({0.0, 0.0, 0.0});
std::vector<float> std({1.0, 1.0, 1.0});
processors->push_back(std::make_shared<NormalizeAndPermute>(mean, std));
// Set alpha and beta
auto processor = dynamic_cast<NormalizeAndPermute*>(
(*processors)[hwc2chw_index - 1].get());
processor->SetAlpha(alpha);
processor->SetBeta(beta);
FDINFO << "Normalize and HWC2CHW are fused to NormalizeAndPermute "
" in preprocessing pipeline."
<< std::endl;
}
void FuseTransforms(
std::vector<std::shared_ptr<Processor>>* processors) {
FuseNormalizeCast(processors);
FuseNormalizeHWC2CHW(processors);
}
} // namespace vision
} // namespace fastdeploy

View File

@@ -21,7 +21,6 @@
#include "fastdeploy/vision/common/processors/crop.h"
#include "fastdeploy/vision/common/processors/hwc2chw.h"
#include "fastdeploy/vision/common/processors/limit_by_stride.h"
#include "fastdeploy/vision/common/processors/limit_long.h"
#include "fastdeploy/vision/common/processors/limit_short.h"
#include "fastdeploy/vision/common/processors/normalize.h"
#include "fastdeploy/vision/common/processors/normalize_and_permute.h"
@@ -36,89 +35,12 @@
namespace fastdeploy {
namespace vision {
inline void FuseNormalizeCast(
std::vector<std::shared_ptr<Processor>>* processors) {
// Fuse Normalize and Cast<Float>
int cast_index = -1;
for (size_t i = 0; i < processors->size(); ++i) {
if ((*processors)[i]->Name() == "Cast") {
if (i == 0) {
continue;
}
if ((*processors)[i - 1]->Name() != "Normalize" &&
(*processors)[i - 1]->Name() != "NormalizeAndPermute") {
continue;
}
cast_index = i;
}
}
if (cast_index < 0) {
return;
}
void FuseTransforms(std::vector<std::shared_ptr<Processor>>* processors);
if (dynamic_cast<Cast*>((*processors)[cast_index].get())->GetDtype() !=
"float") {
return;
}
processors->erase(processors->begin() + cast_index);
FDINFO << (*processors)[cast_index - 1]->Name() << " and Cast are fused to "
<< (*processors)[cast_index - 1]->Name()
<< " in preprocessing pipeline." << std::endl;
}
inline void FuseNormalizeHWC2CHW(
std::vector<std::shared_ptr<Processor>>* processors) {
// Fuse Normalize and HWC2CHW to NormalizeAndPermute
int hwc2chw_index = -1;
for (size_t i = 0; i < processors->size(); ++i) {
if ((*processors)[i]->Name() == "HWC2CHW") {
if (i == 0) {
continue;
}
if ((*processors)[i - 1]->Name() != "Normalize") {
continue;
}
hwc2chw_index = i;
}
}
if (hwc2chw_index < 0) {
return;
}
// Get alpha and beta of Normalize
std::vector<float> alpha =
dynamic_cast<Normalize*>((*processors)[hwc2chw_index - 1].get())
->GetAlpha();
std::vector<float> beta =
dynamic_cast<Normalize*>((*processors)[hwc2chw_index - 1].get())
->GetBeta();
// Delete Normalize and HWC2CHW
processors->erase(processors->begin() + hwc2chw_index);
processors->erase(processors->begin() + hwc2chw_index - 1);
// Add NormalizeAndPermute
std::vector<float> mean({0.0, 0.0, 0.0});
std::vector<float> std({1.0, 1.0, 1.0});
processors->push_back(std::make_shared<NormalizeAndPermute>(mean, std));
// Set alpha and beta
auto processor = dynamic_cast<NormalizeAndPermute*>(
(*processors)[hwc2chw_index - 1].get());
processor->SetAlpha(alpha);
processor->SetBeta(beta);
FDINFO << "Normalize and HWC2CHW are fused to NormalizeAndPermute "
" in preprocessing pipeline."
<< std::endl;
}
inline void FuseTransforms(
std::vector<std::shared_ptr<Processor>>* processors) {
FuseNormalizeCast(processors);
FuseNormalizeHWC2CHW(processors);
}
// Fuse Normalize + Cast(Float) to Normalize
void FuseNormalizeCast(std::vector<std::shared_ptr<Processor>>* processors);
// Fuse Normalize + HWC2CHW to NormalizeAndPermute
void FuseNormalizeHWC2CHW(std::vector<std::shared_ptr<Processor>>* processors);
} // namespace vision
} // namespace fastdeploy

View File

@@ -27,7 +27,7 @@ PPTracking::PPTracking(const std::string& model_file,
const ModelFormat& model_format){
config_file_=config_file;
valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
valid_gpu_backends = {Backend::PDINFER, Backend::ORT};
valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT};
runtime_option = custom_option;
runtime_option.model_format = model_format;
@@ -148,6 +148,8 @@ bool PPTracking::BuildPreprocessPipelineFromConfig(){
}
}
processors_.push_back(std::make_shared<HWC2CHW>());
FuseTransforms(&processors_);
return true;
}

View File

@@ -16,11 +16,19 @@ import logging
import os
import sys
from .c_lib_wrap import (ModelFormat, Backend, rknpu2,
FDDataType, TensorInfo, Device,
FDTensor, is_built_with_gpu, is_built_with_ort,
ModelFormat, is_built_with_paddle, is_built_with_trt,
get_default_cuda_directory, )
from .c_lib_wrap import (
ModelFormat,
Backend,
rknpu2,
FDDataType,
TensorInfo,
Device,
is_built_with_gpu,
is_built_with_ort,
ModelFormat,
is_built_with_paddle,
is_built_with_trt,
get_default_cuda_directory, )
from .runtime import Runtime, RuntimeOption
from .model import FastDeployModel

View File

@@ -49,28 +49,28 @@ setup_configs = dict()
setup_configs["ENABLE_PADDLE_FRONTEND"] = os.getenv("ENABLE_PADDLE_FRONTEND",
"ON")
setup_configs["ENABLE_RKNPU2_BACKEND"] = os.getenv("ENABLE_RKNPU2_BACKEND",
"OFF")
"OFF")
setup_configs["ENABLE_ORT_BACKEND"] = os.getenv("ENABLE_ORT_BACKEND", "OFF")
setup_configs["ENABLE_OPENVINO_BACKEND"] = os.getenv("ENABLE_OPENVINO_BACKEND",
"OFF")
setup_configs["ENABLE_PADDLE_BACKEND"] = os.getenv("ENABLE_PADDLE_BACKEND",
"OFF")
setup_configs["ENABLE_POROS_BACKEND"] = os.getenv("ENABLE_POROS_BACKEND",
"OFF")
setup_configs["ENABLE_POROS_BACKEND"] = os.getenv("ENABLE_POROS_BACKEND", "OFF")
setup_configs["ENABLE_TRT_BACKEND"] = os.getenv("ENABLE_TRT_BACKEND", "OFF")
setup_configs["ENABLE_LITE_BACKEND"] = os.getenv("ENABLE_LITE_BACKEND", "OFF")
setup_configs["ENABLE_VISION"] = os.getenv("ENABLE_VISION", "OFF")
setup_configs["ENABLE_FLYCV"] = os.getenv("ENABLE_FLYCV", "OFF")
setup_configs["ENABLE_TEXT"] = os.getenv("ENABLE_TEXT", "OFF")
setup_configs["ENABLE_TRT_BACKEND"] = os.getenv("ENABLE_TRT_BACKEND", "OFF")
setup_configs["WITH_GPU"] = os.getenv("WITH_GPU", "OFF")
setup_configs["WITH_IPU"] = os.getenv("WITH_IPU", "OFF")
setup_configs["BUILD_ON_JETSON"] = os.getenv("BUILD_ON_JETSON", "OFF")
setup_configs["TRT_DIRECTORY"] = os.getenv("TRT_DIRECTORY", "UNDEFINED")
setup_configs["CUDA_DIRECTORY"] = os.getenv("CUDA_DIRECTORY",
"/usr/local/cuda")
setup_configs["CUDA_DIRECTORY"] = os.getenv("CUDA_DIRECTORY", "/usr/local/cuda")
setup_configs["LIBRARY_NAME"] = PACKAGE_NAME
setup_configs["PY_LIBRARY_NAME"] = PACKAGE_NAME + "_main"
setup_configs["OPENCV_DIRECTORY"] = os.getenv("OPENCV_DIRECTORY", "")
setup_configs["ORT_DIRECTORY"] = os.getenv("ORT_DIRECTORY", "")
setup_configs["RKNN2_TARGET_SOC"] = os.getenv("RKNN2_TARGET_SOC", "")
if setup_configs["WITH_GPU"] == "ON" or setup_configs[
@@ -99,8 +99,7 @@ extras_require = {}
# Default value is set to TRUE\1 to keep the settings same as the current ones.
# However going forward the recomemded way to is to set this to False\0
USE_MSVC_STATIC_RUNTIME = bool(
os.getenv('USE_MSVC_STATIC_RUNTIME', '1') == '1')
USE_MSVC_STATIC_RUNTIME = bool(os.getenv('USE_MSVC_STATIC_RUNTIME', '1') == '1')
ONNX_NAMESPACE = os.getenv('ONNX_NAMESPACE', 'paddle2onnx')
################################################################################
# Version
@@ -130,8 +129,7 @@ assert CMAKE, 'Could not find "cmake" executable!'
@contextmanager
def cd(path):
if not os.path.isabs(path):
raise RuntimeError('Can only cd to absolute path, got: {}'.format(
path))
raise RuntimeError('Can only cd to absolute path, got: {}'.format(path))
orig_path = os.getcwd()
os.chdir(path)
try: