diff --git a/fastdeploy/core/fd_tensor.cc b/fastdeploy/core/fd_tensor.cc index 8b739d844..ff4ae61aa 100644 --- a/fastdeploy/core/fd_tensor.cc +++ b/fastdeploy/core/fd_tensor.cc @@ -131,6 +131,7 @@ void FDTensor::Resize(const std::vector& new_shape, const FDDataType& data_type, const std::string& tensor_name, const Device& new_device) { + external_data_ptr = nullptr; name = tensor_name; device = new_device; dtype = data_type; diff --git a/fastdeploy/core/fd_tensor.h b/fastdeploy/core/fd_tensor.h index 1619fe271..ef7fbff41 100644 --- a/fastdeploy/core/fd_tensor.h +++ b/fastdeploy/core/fd_tensor.h @@ -93,6 +93,12 @@ struct FASTDEPLOY_DECL FDTensor { // Total number of elements in this tensor int Numel() const; + // Get shape of FDTensor + std::vector Shape() const { return shape; } + + // Get dtype of FDTensor + FDDataType Dtype() const { return dtype; } + void Resize(size_t nbytes); void Resize(const std::vector& new_shape); diff --git a/fastdeploy/pybind/fd_tensor.cc b/fastdeploy/pybind/fd_tensor.cc new file mode 100644 index 000000000..2e14b6d18 --- /dev/null +++ b/fastdeploy/pybind/fd_tensor.cc @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/fastdeploy_model.h" +#include "fastdeploy/pybind/main.h" + +namespace fastdeploy { + +void BindFDTensor(pybind11::module& m) { + pybind11::class_(m, "FDTensor") + .def(pybind11::init<>(), "Default Constructor") + .def_readwrite("name", &FDTensor::name) + .def_readonly("shape", &FDTensor::shape) + .def_readonly("dtype", &FDTensor::dtype) + .def_readonly("device", &FDTensor::device) + .def("numpy", [](FDTensor& self) { + return TensorToPyArray(self); + }) + .def("from_numpy", [](FDTensor& self, pybind11::array& pyarray, bool share_buffer = false) { + PyArrayToTensor(pyarray, &self, share_buffer); + }); +} + +} // namespace fastdeploy diff --git a/fastdeploy/pybind/main.cc.in b/fastdeploy/pybind/main.cc.in index 74fe90433..97aafc64a 100644 --- a/fastdeploy/pybind/main.cc.in +++ b/fastdeploy/pybind/main.cc.in @@ -16,6 +16,7 @@ namespace fastdeploy { +void BindFDTensor(pybind11::module&); void BindRuntime(pybind11::module&); void BindFDModel(pybind11::module&); void BindVision(pybind11::module&); @@ -70,7 +71,7 @@ void PyArrayToTensor(pybind11::array& pyarray, FDTensor* tensor, data_shape.insert(data_shape.begin(), pyarray.shape(), pyarray.shape() + pyarray.ndim()); if (share_buffer) { - tensor-> SetExternalData(data_shape, dtype, + tensor->SetExternalData(data_shape, dtype, pyarray.mutable_data()); } else { tensor->Resize(data_shape, dtype); @@ -80,6 +81,7 @@ void PyArrayToTensor(pybind11::array& pyarray, FDTensor* tensor, void PyArrayToTensorList(std::vector& pyarrays, std::vector* tensors, bool share_buffer) { + tensors->resize(pyarrays.size()); for(auto i = 0; i < pyarrays.size(); ++i) { PyArrayToTensor(pyarrays[i], &(*tensors)[i], share_buffer); } @@ -88,7 +90,7 @@ void PyArrayToTensorList(std::vector& pyarrays, std::vector #include #include +#include #include diff --git a/fastdeploy/pybind/runtime.cc b/fastdeploy/pybind/runtime.cc index a694be970..11cf9bf4e 100755 --- a/fastdeploy/pybind/runtime.cc +++ b/fastdeploy/pybind/runtime.cc @@ -162,6 +162,25 @@ void BindRuntime(pybind11::module& m) { } return results; }) + .def("infer", [](Runtime& self, std::map& data) { + std::vector inputs; + inputs.reserve(data.size()); + for (auto iter = data.begin(); iter != data.end(); ++iter) { + FDTensor tensor; + tensor.SetExternalData(iter->second.Shape(), iter->second.Dtype(), iter->second.Data(), iter->second.device); + tensor.name = iter->first; + inputs.push_back(tensor); + } + std::vector outputs; + if (!self.Infer(inputs, &outputs)) { + pybind11::eval("raise Exception('Failed to inference with Runtime.')"); + } + return outputs; + }) + .def("infer", [](Runtime& self, std::vector& inputs) { + std::vector outputs; + return self.Infer(inputs, &outputs); + }) .def("num_inputs", &Runtime::NumInputs) .def("num_outputs", &Runtime::NumOutputs) .def("get_input_info", &Runtime::GetInputInfo) @@ -202,33 +221,6 @@ void BindRuntime(pybind11::module& m) { .value("FP64", FDDataType::FP64) .value("UINT8", FDDataType::UINT8); - pybind11::class_(m, "FDTensor", pybind11::buffer_protocol()) - .def(pybind11::init()) - .def("cpu_data", - [](FDTensor& self) { - auto ptr = self.CpuData(); - auto numel = self.Numel(); - auto dtype = FDDataTypeToNumpyDataType(self.dtype); - auto base = pybind11::array(dtype, self.shape); - return pybind11::array(dtype, self.shape, ptr, base); - }) - .def("resize", static_cast(&FDTensor::Resize)) - .def("resize", - static_cast&)>( - &FDTensor::Resize)) - .def( - "resize", - [](FDTensor& self, const std::vector& shape, - const FDDataType& dtype, const std::string& name, - const Device& device) { self.Resize(shape, dtype, name, device); }) - .def("numel", &FDTensor::Numel) - .def("nbytes", &FDTensor::Nbytes) - .def_readwrite("name", &FDTensor::name) - .def_readwrite("is_pinned_memory", &FDTensor::is_pinned_memory) - .def_readonly("shape", &FDTensor::shape) - .def_readonly("dtype", &FDTensor::dtype) - .def_readonly("device", &FDTensor::device); - m.def("get_available_backends", []() { return GetAvailableBackends(); }); } diff --git a/fastdeploy/vision/common/processors/limit_long.cc b/fastdeploy/vision/common/processors/limit_long.cc deleted file mode 100644 index 7021f131b..000000000 --- a/fastdeploy/vision/common/processors/limit_long.cc +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "fastdeploy/vision/common/processors/limit_long.h" - -namespace fastdeploy { -namespace vision { - -bool LimitLong::ImplByOpenCV(Mat* mat) { - cv::Mat* im = mat->GetOpenCVMat(); - int origin_w = im->cols; - int origin_h = im->rows; - int im_size_max = std::max(origin_w, origin_h); - int target = im_size_max; - if (max_long_ > 0 && im_size_max > max_long_) { - target = max_long_; - } else if (min_long_ > 0 && im_size_max < min_long_) { - target = min_long_; - } - if (target != im_size_max) { - double scale = - static_cast(target) / static_cast(im_size_max); - cv::resize(*im, *im, cv::Size(), scale, scale, interp_); - mat->SetWidth(im->cols); - mat->SetHeight(im->rows); - } - return true; -} - -#ifdef ENABLE_FLYCV -bool LimitLong::ImplByFlyCV(Mat* mat) { - fcv::Mat* im = mat->GetFlyCVMat(); - int origin_w = im->width(); - int origin_h = im->height(); - int im_size_max = std::max(origin_w, origin_h); - int target = im_size_max; - if (max_long_ > 0 && im_size_max > max_long_) { - target = max_long_; - } else if (min_long_ > 0 && im_size_max < min_long_) { - target = min_long_; - } - if (target != im_size_max) { - double scale = - static_cast(target) / static_cast(im_size_max); - if (fabs(scale - 1.0) < 1e-06) { - return true; - } - auto interp_method = fcv::InterpolationType::INTER_LINEAR; - if (interp_ == 0) { - interp_method = fcv::InterpolationType::INTER_NEAREST; - } else if (interp_ == 1) { - interp_method = fcv::InterpolationType::INTER_LINEAR; - } else if (interp_ == 2) { - interp_method = fcv::InterpolationType::INTER_CUBIC; - } else { - FDERROR << "LimitLong: Only support interp_ be 0/1/2 with FlyCV, but " - "now it's " - << interp_ << "." << std::endl; - return false; - } - fcv::Mat new_im; - fcv::resize(*im, new_im, fcv::Size(), scale, scale, interp_method); - mat->SetMat(new_im); - mat->SetWidth(new_im.width()); - mat->SetHeight(new_im.height()); - } - return true; -} -#endif - -bool LimitLong::Run(Mat* mat, int max_long, int min_long, int interp, - ProcLib lib) { - auto l = LimitLong(max_long, min_long, interp); - return l(mat, lib); -} -} // namespace vision -} // namespace fastdeploy diff --git a/fastdeploy/vision/common/processors/limit_long.h b/fastdeploy/vision/common/processors/limit_long.h deleted file mode 100644 index 49055973d..000000000 --- a/fastdeploy/vision/common/processors/limit_long.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "fastdeploy/vision/common/processors/base.h" - -namespace fastdeploy { -namespace vision { - -class FASTDEPLOY_DECL LimitLong : public Processor { - public: - explicit LimitLong(int max_long = -1, int min_long = -1, int interp = 1) { - max_long_ = max_long; - min_long_ = min_long; - interp_ = interp; - } - - // Limit the long edge of image. - // If the long edge is larger than max_long_, resize the long edge - // to max_long_, while scale the short edge proportionally. - // If the long edge is smaller than min_long_, resize the long edge - // to min_long_, while scale the short edge proportionally. - bool ImplByOpenCV(Mat* mat); -#ifdef ENABLE_FLYCV - bool ImplByFlyCV(Mat* mat); -#endif - std::string Name() { return "LimitLong"; } - - static bool Run(Mat* mat, int max_long = -1, int min_long = -1, - int interp = 1, ProcLib lib = ProcLib::DEFAULT); - int GetMaxLong() const { return max_long_; } - - private: - int max_long_; - int min_long_; - int interp_; -}; -} // namespace vision -} // namespace fastdeploy diff --git a/fastdeploy/vision/common/processors/limit_short.cc b/fastdeploy/vision/common/processors/limit_short.cc index d0f0697c8..c4e909744 100644 --- a/fastdeploy/vision/common/processors/limit_short.cc +++ b/fastdeploy/vision/common/processors/limit_short.cc @@ -65,7 +65,7 @@ bool LimitShort::ImplByFlyCV(Mat* mat) { } else if (interp_ == 2) { interp_method = fcv::InterpolationType::INTER_CUBIC; } else { - FDERROR << "LimitLong: Only support interp_ be 0/1/2 with FlyCV, but " + FDERROR << "LimitShort: Only support interp_ be 0/1/2 with FlyCV, but " "now it's " << interp_ << "." << std::endl; return false; diff --git a/fastdeploy/vision/common/processors/mat.cc b/fastdeploy/vision/common/processors/mat.cc index e2a64ea04..7ef4f9d70 100644 --- a/fastdeploy/vision/common/processors/mat.cc +++ b/fastdeploy/vision/common/processors/mat.cc @@ -174,5 +174,18 @@ Mat Mat::Create(int height, int width, int channels, return mat; } +FDMat WrapMat(const cv::Mat& image) { + FDMat mat(image); + return mat; +} + +std::vector WrapMat(const std::vector& images) { + std::vector mats; + for (size_t i = 0; i < images.size(); ++i) { + mats.emplace_back(FDMat(images[i])); + } + return mats; +} + } // namespace vision } // namespace fastdeploy diff --git a/fastdeploy/vision/common/processors/mat.h b/fastdeploy/vision/common/processors/mat.h index 5e618057c..525370043 100644 --- a/fastdeploy/vision/common/processors/mat.h +++ b/fastdeploy/vision/common/processors/mat.h @@ -147,5 +147,15 @@ struct FASTDEPLOY_DECL Mat { FDDataType type, void* data, ProcLib lib); }; +typedef Mat FDMat; +/* + * @brief Wrap a cv::Mat to FDMat, there's no memory copy, memory buffer is managed by user + */ +FASTDEPLOY_DECL FDMat WrapMat(const cv::Mat& image); +/* + * Warp a vector to vector, there's no memory copy, memory buffer is managed by user + */ +FASTDEPLOY_DECL std::vector WrapMat(const std::vector& images); + } // namespace vision } // namespace fastdeploy diff --git a/fastdeploy/vision/common/processors/resize.cc b/fastdeploy/vision/common/processors/resize.cc index 28488c2cd..da9104059 100644 --- a/fastdeploy/vision/common/processors/resize.cc +++ b/fastdeploy/vision/common/processors/resize.cc @@ -79,7 +79,7 @@ bool Resize::ImplByFlyCV(Mat* mat) { } else if (interp_ == 2) { interp_method = fcv::InterpolationType::INTER_CUBIC; } else { - FDERROR << "LimitLong: Only support interp_ be 0/1/2 with FlyCV, but " + FDERROR << "Resize: Only support interp_ be 0/1/2 with FlyCV, but " "now it's " << interp_ << "." << std::endl; return false; diff --git a/fastdeploy/vision/common/processors/transform.cc b/fastdeploy/vision/common/processors/transform.cc new file mode 100644 index 000000000..8d440b9c6 --- /dev/null +++ b/fastdeploy/vision/common/processors/transform.cc @@ -0,0 +1,106 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision/common/processors/transform.h" + +namespace fastdeploy { +namespace vision { + +void FuseNormalizeCast( + std::vector>* processors) { + // Fuse Normalize and Cast + int cast_index = -1; + for (size_t i = 0; i < processors->size(); ++i) { + if ((*processors)[i]->Name() == "Cast") { + if (i == 0) { + continue; + } + if ((*processors)[i - 1]->Name() != "Normalize" && + (*processors)[i - 1]->Name() != "NormalizeAndPermute") { + continue; + } + cast_index = i; + } + } + if (cast_index < 0) { + return; + } + + if (dynamic_cast((*processors)[cast_index].get())->GetDtype() != + "float") { + return; + } + processors->erase(processors->begin() + cast_index); + FDINFO << (*processors)[cast_index - 1]->Name() << " and Cast are fused to " + << (*processors)[cast_index - 1]->Name() + << " in preprocessing pipeline." << std::endl; +} + +void FuseNormalizeHWC2CHW( + std::vector>* processors) { + // Fuse Normalize and HWC2CHW to NormalizeAndPermute + int hwc2chw_index = -1; + for (size_t i = 0; i < processors->size(); ++i) { + if ((*processors)[i]->Name() == "HWC2CHW") { + if (i == 0) { + continue; + } + if ((*processors)[i - 1]->Name() != "Normalize") { + continue; + } + hwc2chw_index = i; + } + } + + if (hwc2chw_index < 0) { + return; + } + + // Get alpha and beta of Normalize + std::vector alpha = + dynamic_cast((*processors)[hwc2chw_index - 1].get()) + ->GetAlpha(); + std::vector beta = + dynamic_cast((*processors)[hwc2chw_index - 1].get()) + ->GetBeta(); + + // Delete Normalize and HWC2CHW + processors->erase(processors->begin() + hwc2chw_index); + processors->erase(processors->begin() + hwc2chw_index - 1); + + // Add NormalizeAndPermute + std::vector mean({0.0, 0.0, 0.0}); + std::vector std({1.0, 1.0, 1.0}); + processors->push_back(std::make_shared(mean, std)); + + // Set alpha and beta + auto processor = dynamic_cast( + (*processors)[hwc2chw_index - 1].get()); + + processor->SetAlpha(alpha); + processor->SetBeta(beta); + FDINFO << "Normalize and HWC2CHW are fused to NormalizeAndPermute " + " in preprocessing pipeline." + << std::endl; +} + +void FuseTransforms( + std::vector>* processors) { + FuseNormalizeCast(processors); + FuseNormalizeHWC2CHW(processors); +} + + +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/common/processors/transform.h b/fastdeploy/vision/common/processors/transform.h index 9054ade55..53f7ffd63 100644 --- a/fastdeploy/vision/common/processors/transform.h +++ b/fastdeploy/vision/common/processors/transform.h @@ -21,7 +21,6 @@ #include "fastdeploy/vision/common/processors/crop.h" #include "fastdeploy/vision/common/processors/hwc2chw.h" #include "fastdeploy/vision/common/processors/limit_by_stride.h" -#include "fastdeploy/vision/common/processors/limit_long.h" #include "fastdeploy/vision/common/processors/limit_short.h" #include "fastdeploy/vision/common/processors/normalize.h" #include "fastdeploy/vision/common/processors/normalize_and_permute.h" @@ -36,89 +35,12 @@ namespace fastdeploy { namespace vision { -inline void FuseNormalizeCast( - std::vector>* processors) { - // Fuse Normalize and Cast - int cast_index = -1; - for (size_t i = 0; i < processors->size(); ++i) { - if ((*processors)[i]->Name() == "Cast") { - if (i == 0) { - continue; - } - if ((*processors)[i - 1]->Name() != "Normalize" && - (*processors)[i - 1]->Name() != "NormalizeAndPermute") { - continue; - } - cast_index = i; - } - } - if (cast_index < 0) { - return; - } +void FuseTransforms(std::vector>* processors); - if (dynamic_cast((*processors)[cast_index].get())->GetDtype() != - "float") { - return; - } - processors->erase(processors->begin() + cast_index); - FDINFO << (*processors)[cast_index - 1]->Name() << " and Cast are fused to " - << (*processors)[cast_index - 1]->Name() - << " in preprocessing pipeline." << std::endl; -} - -inline void FuseNormalizeHWC2CHW( - std::vector>* processors) { - // Fuse Normalize and HWC2CHW to NormalizeAndPermute - int hwc2chw_index = -1; - for (size_t i = 0; i < processors->size(); ++i) { - if ((*processors)[i]->Name() == "HWC2CHW") { - if (i == 0) { - continue; - } - if ((*processors)[i - 1]->Name() != "Normalize") { - continue; - } - hwc2chw_index = i; - } - } - - if (hwc2chw_index < 0) { - return; - } - - // Get alpha and beta of Normalize - std::vector alpha = - dynamic_cast((*processors)[hwc2chw_index - 1].get()) - ->GetAlpha(); - std::vector beta = - dynamic_cast((*processors)[hwc2chw_index - 1].get()) - ->GetBeta(); - - // Delete Normalize and HWC2CHW - processors->erase(processors->begin() + hwc2chw_index); - processors->erase(processors->begin() + hwc2chw_index - 1); - - // Add NormalizeAndPermute - std::vector mean({0.0, 0.0, 0.0}); - std::vector std({1.0, 1.0, 1.0}); - processors->push_back(std::make_shared(mean, std)); - - // Set alpha and beta - auto processor = dynamic_cast( - (*processors)[hwc2chw_index - 1].get()); - - processor->SetAlpha(alpha); - processor->SetBeta(beta); - FDINFO << "Normalize and HWC2CHW are fused to NormalizeAndPermute " - " in preprocessing pipeline." - << std::endl; -} - -inline void FuseTransforms( - std::vector>* processors) { - FuseNormalizeCast(processors); - FuseNormalizeHWC2CHW(processors); -} +// Fuse Normalize + Cast(Float) to Normalize +void FuseNormalizeCast(std::vector>* processors); +// Fuse Normalize + HWC2CHW to NormalizeAndPermute +void FuseNormalizeHWC2CHW(std::vector>* processors); } // namespace vision } // namespace fastdeploy diff --git a/fastdeploy/vision/tracking/pptracking/model.cc b/fastdeploy/vision/tracking/pptracking/model.cc index 2047497b9..a4e6c175b 100644 --- a/fastdeploy/vision/tracking/pptracking/model.cc +++ b/fastdeploy/vision/tracking/pptracking/model.cc @@ -27,7 +27,7 @@ PPTracking::PPTracking(const std::string& model_file, const ModelFormat& model_format){ config_file_=config_file; valid_cpu_backends = {Backend::PDINFER, Backend::ORT}; - valid_gpu_backends = {Backend::PDINFER, Backend::ORT}; + valid_gpu_backends = {Backend::PDINFER, Backend::ORT, Backend::TRT}; runtime_option = custom_option; runtime_option.model_format = model_format; @@ -148,6 +148,8 @@ bool PPTracking::BuildPreprocessPipelineFromConfig(){ } } processors_.push_back(std::make_shared()); + + FuseTransforms(&processors_); return true; } diff --git a/python/fastdeploy/__init__.py b/python/fastdeploy/__init__.py index 36e52c911..b767393f1 100644 --- a/python/fastdeploy/__init__.py +++ b/python/fastdeploy/__init__.py @@ -16,11 +16,19 @@ import logging import os import sys -from .c_lib_wrap import (ModelFormat, Backend, rknpu2, - FDDataType, TensorInfo, Device, - FDTensor, is_built_with_gpu, is_built_with_ort, - ModelFormat, is_built_with_paddle, is_built_with_trt, - get_default_cuda_directory, ) +from .c_lib_wrap import ( + ModelFormat, + Backend, + rknpu2, + FDDataType, + TensorInfo, + Device, + is_built_with_gpu, + is_built_with_ort, + ModelFormat, + is_built_with_paddle, + is_built_with_trt, + get_default_cuda_directory, ) from .runtime import Runtime, RuntimeOption from .model import FastDeployModel diff --git a/python/setup.py b/python/setup.py index 10f57a2cb..2f3183222 100755 --- a/python/setup.py +++ b/python/setup.py @@ -49,28 +49,28 @@ setup_configs = dict() setup_configs["ENABLE_PADDLE_FRONTEND"] = os.getenv("ENABLE_PADDLE_FRONTEND", "ON") setup_configs["ENABLE_RKNPU2_BACKEND"] = os.getenv("ENABLE_RKNPU2_BACKEND", - "OFF") + "OFF") setup_configs["ENABLE_ORT_BACKEND"] = os.getenv("ENABLE_ORT_BACKEND", "OFF") setup_configs["ENABLE_OPENVINO_BACKEND"] = os.getenv("ENABLE_OPENVINO_BACKEND", "OFF") setup_configs["ENABLE_PADDLE_BACKEND"] = os.getenv("ENABLE_PADDLE_BACKEND", "OFF") -setup_configs["ENABLE_POROS_BACKEND"] = os.getenv("ENABLE_POROS_BACKEND", - "OFF") +setup_configs["ENABLE_POROS_BACKEND"] = os.getenv("ENABLE_POROS_BACKEND", "OFF") +setup_configs["ENABLE_TRT_BACKEND"] = os.getenv("ENABLE_TRT_BACKEND", "OFF") +setup_configs["ENABLE_LITE_BACKEND"] = os.getenv("ENABLE_LITE_BACKEND", "OFF") setup_configs["ENABLE_VISION"] = os.getenv("ENABLE_VISION", "OFF") setup_configs["ENABLE_FLYCV"] = os.getenv("ENABLE_FLYCV", "OFF") setup_configs["ENABLE_TEXT"] = os.getenv("ENABLE_TEXT", "OFF") -setup_configs["ENABLE_TRT_BACKEND"] = os.getenv("ENABLE_TRT_BACKEND", "OFF") setup_configs["WITH_GPU"] = os.getenv("WITH_GPU", "OFF") setup_configs["WITH_IPU"] = os.getenv("WITH_IPU", "OFF") setup_configs["BUILD_ON_JETSON"] = os.getenv("BUILD_ON_JETSON", "OFF") setup_configs["TRT_DIRECTORY"] = os.getenv("TRT_DIRECTORY", "UNDEFINED") -setup_configs["CUDA_DIRECTORY"] = os.getenv("CUDA_DIRECTORY", - "/usr/local/cuda") +setup_configs["CUDA_DIRECTORY"] = os.getenv("CUDA_DIRECTORY", "/usr/local/cuda") setup_configs["LIBRARY_NAME"] = PACKAGE_NAME setup_configs["PY_LIBRARY_NAME"] = PACKAGE_NAME + "_main" setup_configs["OPENCV_DIRECTORY"] = os.getenv("OPENCV_DIRECTORY", "") setup_configs["ORT_DIRECTORY"] = os.getenv("ORT_DIRECTORY", "") + setup_configs["RKNN2_TARGET_SOC"] = os.getenv("RKNN2_TARGET_SOC", "") if setup_configs["WITH_GPU"] == "ON" or setup_configs[ @@ -99,8 +99,7 @@ extras_require = {} # Default value is set to TRUE\1 to keep the settings same as the current ones. # However going forward the recomemded way to is to set this to False\0 -USE_MSVC_STATIC_RUNTIME = bool( - os.getenv('USE_MSVC_STATIC_RUNTIME', '1') == '1') +USE_MSVC_STATIC_RUNTIME = bool(os.getenv('USE_MSVC_STATIC_RUNTIME', '1') == '1') ONNX_NAMESPACE = os.getenv('ONNX_NAMESPACE', 'paddle2onnx') ################################################################################ # Version @@ -130,8 +129,7 @@ assert CMAKE, 'Could not find "cmake" executable!' @contextmanager def cd(path): if not os.path.isabs(path): - raise RuntimeError('Can only cd to absolute path, got: {}'.format( - path)) + raise RuntimeError('Can only cd to absolute path, got: {}'.format(path)) orig_path = os.getcwd() os.chdir(path) try: