[CVCUDA] Utilize CV-CUDA batch processing function (#1223)
* norm and permute batch processing
* move cache to mat, batch processors
* get batched tensor logic, resize on cpu logic
* fix cpu compile error
* remove vector mat api
* nits
* add comments
* nits
* fix batch size
* move initial resize on cpu option to use_cuda api
* fix pybind
* processor manager pybind
* rename mat and matbatch
* move initial resize on cpu to ppcls preprocessor

Co-authored-by: Jason <jiangjiajun@baidu.com>
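The commit also exposes the new options through the Python bindings. Below is a minimal usage sketch (not part of the commit) of the batched preprocessing path; it assumes a FastDeploy build with WITH_GPU=ON and ENABLE_CVCUDA=ON, and the config/image paths are placeholders.

# Minimal sketch: exercise the batched CV-CUDA preprocessing path added here.
# Paths are placeholders, not files from this repository.
import cv2
import fastdeploy as fd

preprocessor = fd.vision.classification.PaddleClasPreprocessor(
    "resnet50/inference_cls.yaml")
preprocessor.use_cuda(enable_cv_cuda=True, gpu_id=0)  # CV-CUDA batch processors
preprocessor.initial_resize_on_cpu(True)  # keep the first Resize on OpenCV/CPU

images = [cv2.imread("test1.jpg"), cv2.imread("test2.jpg")]
outputs = preprocessor.run(images)  # a single batched FDTensor, left on the GPU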
@@ -12,11 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/core/float16.h"
#include "fastdeploy/utils/utils.h"

#include <algorithm>
#include <cstring>

#include "fastdeploy/core/float16.h"
#include "fastdeploy/utils/utils.h"
#ifdef WITH_GPU
#include <cuda_runtime_api.h>
#endif

@@ -142,6 +143,9 @@ void FDTensor::Resize(const std::vector<int64_t>& new_shape,
const FDDataType& data_type,
const std::string& tensor_name,
const Device& new_device) {
if (device != new_device) {
FreeFn();
}
external_data_ptr = nullptr;
name = tensor_name;
device = new_device;

@@ -269,7 +273,8 @@ bool FDTensor::ReallocFn(size_t nbytes) {
}
return buffer_ != nullptr;
#else
FDASSERT(false, "The FastDeploy FDTensor allocator didn't compile under "
FDASSERT(false,
"The FastDeploy FDTensor allocator didn't compile under "
"-DWITH_GPU=ON,"
"so this is an unexpected problem happend.");
#endif

@@ -285,7 +290,8 @@ bool FDTensor::ReallocFn(size_t nbytes) {
}
return buffer_ != nullptr;
#else
FDASSERT(false, "The FastDeploy FDTensor allocator didn't compile under "
FDASSERT(false,
"The FastDeploy FDTensor allocator didn't compile under "
"-DWITH_GPU=ON,"
"so this is an unexpected problem happend.");
#endif

@@ -296,8 +302,7 @@ bool FDTensor::ReallocFn(size_t nbytes) {
}

void FDTensor::FreeFn() {
if (external_data_ptr != nullptr)
external_data_ptr = nullptr;
if (external_data_ptr != nullptr) external_data_ptr = nullptr;
if (buffer_ != nullptr) {
if (device == Device::GPU) {
#ifdef WITH_GPU

@@ -386,8 +391,11 @@ FDTensor::FDTensor(const Scalar& scalar) {
}

FDTensor::FDTensor(const FDTensor& other)
: shape(other.shape), name(other.name), dtype(other.dtype),
device(other.device), external_data_ptr(other.external_data_ptr),
: shape(other.shape),
name(other.name),
dtype(other.dtype),
device(other.device),
external_data_ptr(other.external_data_ptr),
device_id(other.device_id) {
// Copy buffer
if (other.buffer_ == nullptr) {

@@ -401,9 +409,12 @@ FDTensor::FDTensor(const FDTensor& other)
}

FDTensor::FDTensor(FDTensor&& other)
: buffer_(other.buffer_), shape(std::move(other.shape)),
name(std::move(other.name)), dtype(other.dtype),
external_data_ptr(other.external_data_ptr), device(other.device),
: buffer_(other.buffer_),
shape(std::move(other.shape)),
name(std::move(other.name)),
dtype(other.dtype),
external_data_ptr(other.external_data_ptr),
device(other.device),
device_id(other.device_id) {
other.name = "";
// Note(zhoushunjie): Avoid double free.
@@ -15,33 +15,9 @@

namespace fastdeploy {
void BindPaddleClas(pybind11::module& m) {
pybind11::class_<vision::classification::PaddleClasPreprocessor>(
m, "PaddleClasPreprocessor")
pybind11::class_<vision::classification::PaddleClasPreprocessor,
vision::ProcessorManager>(m, "PaddleClasPreprocessor")
.def(pybind11::init<std::string>())
.def("run",
[](vision::classification::PaddleClasPreprocessor& self,
std::vector<pybind11::array>& im_list) {
std::vector<vision::FDMat> images;
for (size_t i = 0; i < im_list.size(); ++i) {
images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
}
std::vector<FDTensor> outputs;
if (!self.Run(&images, &outputs)) {
throw std::runtime_error(
"Failed to preprocess the input data in "
"PaddleClasPreprocessor.");
}
if (!self.CudaUsed()) {
for (size_t i = 0; i < outputs.size(); ++i) {
outputs[i].StopSharing();
}
}
return outputs;
})
.def("use_cuda",
[](vision::classification::PaddleClasPreprocessor& self,
bool enable_cv_cuda = false,
int gpu_id = -1) { self.UseCuda(enable_cv_cuda, gpu_id); })
.def("disable_normalize",
[](vision::classification::PaddleClasPreprocessor& self) {
self.DisableNormalize();

@@ -49,6 +25,10 @@ void BindPaddleClas(pybind11::module& m) {
.def("disable_permute",
[](vision::classification::PaddleClasPreprocessor& self) {
self.DisablePermute();
})
.def("initial_resize_on_cpu",
[](vision::classification::PaddleClasPreprocessor& self, bool v) {
self.InitialResizeOnCpu(v);
});

pybind11::class_<vision::classification::PaddleClasPostprocessor>(
@@ -100,32 +100,23 @@ void PaddleClasPreprocessor::DisablePermute() {
}
}

bool PaddleClasPreprocessor::Apply(std::vector<FDMat>* images,
bool PaddleClasPreprocessor::Apply(FDMatBatch* image_batch,
std::vector<FDTensor>* outputs) {
for (size_t i = 0; i < images->size(); ++i) {
for (size_t j = 0; j < processors_.size(); ++j) {
bool ret = false;
ret = (*(processors_[j].get()))(&((*images)[i]));
if (!ret) {
FDERROR << "Failed to processs image:" << i << " in "
<< processors_[j]->Name() << "." << std::endl;
return false;
ProcLib lib = ProcLib::DEFAULT;
if (initial_resize_on_cpu_ && j == 0 &&
processors_[j]->Name().find("Resize") == 0) {
lib = ProcLib::OPENCV;
}
if (!(*(processors_[j].get()))(image_batch, lib)) {
FDERROR << "Failed to processs image in " << processors_[j]->Name() << "."
<< std::endl;
return false;
}
}

outputs->resize(1);
// Concat all the preprocessed data to a batch tensor
std::vector<FDTensor> tensors(images->size());
for (size_t i = 0; i < images->size(); ++i) {
(*images)[i].ShareWithTensor(&(tensors[i]));
tensors[i].ExpandDim(0);
}
if (tensors.size() == 1) {
(*outputs)[0] = std::move(tensors[0]);
} else {
function::Concat(tensors, &((*outputs)[0]), 0);
}
(*outputs)[0] = std::move(*(image_batch->Tensor()));
(*outputs)[0].device_id = DeviceId();
return true;
}
@@ -33,11 +33,11 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor : public ProcessorManager {

/** \brief Process the input image and prepare input tensors for runtime
 *
 * \param[in] images The input image data list, all the elements are returned by cv::imread()
 * \param[in] image_batch The input image batch
 * \param[in] outputs The output tensors which will feed in runtime
 * \return true if the preprocess successed, otherwise false
 */
virtual bool Apply(std::vector<FDMat>* images,
virtual bool Apply(FDMatBatch* image_batch,
                   std::vector<FDTensor>* outputs);

/// This function will disable normalize in preprocessing step.

@@ -45,6 +45,14 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor : public ProcessorManager {
/// This function will disable hwc2chw in preprocessing step.
void DisablePermute();

/** \brief When the initial operator is Resize, and input image size is large,
 * maybe it's better to run resize on CPU, because the HostToDevice memcpy
 * is time consuming. Set this true to run the initial resize on CPU.
 *
 * \param[in] v ture or false
 */
void InitialResizeOnCpu(bool v) { initial_resize_on_cpu_ = v; }

private:
bool BuildPreprocessPipelineFromConfig();
std::vector<std::shared_ptr<Processor>> processors_;

@@ -54,6 +62,7 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor : public ProcessorManager {
bool disable_normalize_ = false;
// read config file
std::string config_file_;
bool initial_resize_on_cpu_ = false;
};

} // namespace classification
@@ -20,7 +20,7 @@
namespace fastdeploy {
namespace vision {

bool Processor::operator()(Mat* mat, ProcLib lib) {
bool Processor::operator()(FDMat* mat, ProcLib lib) {
ProcLib target = lib;
if (lib == ProcLib::DEFAULT) {
target = DefaultProcLib::default_lib;

@@ -52,39 +52,38 @@ bool Processor::operator()(Mat* mat, ProcLib lib) {
return ImplByOpenCV(mat);
}

FDTensor* Processor::UpdateAndGetCachedTensor(
const std::vector<int64_t>& new_shape, const FDDataType& data_type,
const std::string& tensor_name, const Device& new_device,
const bool& use_pinned_memory) {
if (cached_tensors_.count(tensor_name) == 0) {
cached_tensors_[tensor_name] = FDTensor();
bool Processor::operator()(FDMatBatch* mat_batch, ProcLib lib) {
ProcLib target = lib;
if (lib == ProcLib::DEFAULT) {
target = DefaultProcLib::default_lib;
}
cached_tensors_[tensor_name].is_pinned_memory = use_pinned_memory;
cached_tensors_[tensor_name].Resize(new_shape, data_type, tensor_name,
new_device);
return &cached_tensors_[tensor_name];
}

FDTensor* Processor::CreateCachedGpuInputTensor(
Mat* mat, const std::string& tensor_name) {
if (target == ProcLib::FLYCV) {
#ifdef ENABLE_FLYCV
return ImplByFlyCV(mat_batch);
#else
FDASSERT(false, "FastDeploy didn't compile with FlyCV.");
#endif
} else if (target == ProcLib::CUDA) {
#ifdef WITH_GPU
FDTensor* src = mat->Tensor();
if (src->device == Device::GPU) {
return src;
} else if (src->device == Device::CPU) {
FDTensor* tensor = UpdateAndGetCachedTensor(src->Shape(), src->Dtype(),
tensor_name, Device::GPU);
FDASSERT(cudaMemcpyAsync(tensor->Data(), src->Data(), tensor->Nbytes(),
cudaMemcpyHostToDevice, mat->Stream()) == 0,
"[ERROR] Error occurs while copy memory from CPU to GPU.");
return tensor;
} else {
FDASSERT(false, "FDMat is on unsupported device: %d", src->device);
}
FDASSERT(
mat_batch->Stream() != nullptr,
"CUDA processor requires cuda stream, please set stream for mat_batch");
return ImplByCuda(mat_batch);
#else
FDASSERT(false, "FastDeploy didn't compile with WITH_GPU.");
#endif
return nullptr;
} else if (target == ProcLib::CVCUDA) {
#ifdef ENABLE_CVCUDA
FDASSERT(mat_batch->Stream() != nullptr,
"CV-CUDA processor requires cuda stream, please set stream for "
"mat_batch");
return ImplByCvCuda(mat_batch);
#else
FDASSERT(false, "FastDeploy didn't compile with CV-CUDA.");
#endif
}
// DEFAULT & OPENCV
return ImplByOpenCV(mat_batch);
}

void EnableFlyCV() {
@@ -16,6 +16,7 @@

#include "fastdeploy/utils/utils.h"
#include "fastdeploy/vision/common/processors/mat.h"
#include "fastdeploy/vision/common/processors/mat_batch.h"
#include "opencv2/highgui/highgui.hpp"
#include "opencv2/imgproc/imgproc.hpp"
#include <unordered_map>

@@ -46,46 +47,63 @@ class FASTDEPLOY_DECL Processor {

virtual std::string Name() = 0;

virtual bool ImplByOpenCV(Mat* mat) {
virtual bool ImplByOpenCV(FDMat* mat) {
FDERROR << Name() << " Not Implement Yet." << std::endl;
return false;
}

virtual bool ImplByFlyCV(Mat* mat) {
virtual bool ImplByOpenCV(FDMatBatch* mat_batch) {
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
if (ImplByOpenCV(&(*(mat_batch->mats))[i]) != true) {
return false;
}
}
return true;
}

virtual bool ImplByFlyCV(FDMat* mat) {
return ImplByOpenCV(mat);
}

virtual bool ImplByCuda(Mat* mat) {
virtual bool ImplByFlyCV(FDMatBatch* mat_batch) {
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
if (ImplByFlyCV(&(*(mat_batch->mats))[i]) != true) {
return false;
}
}
return true;
}

virtual bool ImplByCuda(FDMat* mat) {
return ImplByOpenCV(mat);
}

virtual bool ImplByCvCuda(Mat* mat) {
virtual bool ImplByCuda(FDMatBatch* mat_batch) {
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
if (ImplByCuda(&(*(mat_batch->mats))[i]) != true) {
return false;
}
}
return true;
}

virtual bool ImplByCvCuda(FDMat* mat) {
return ImplByOpenCV(mat);
}

virtual bool operator()(Mat* mat, ProcLib lib = ProcLib::DEFAULT);
virtual bool ImplByCvCuda(FDMatBatch* mat_batch) {
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
if (ImplByCvCuda(&(*(mat_batch->mats))[i]) != true) {
return false;
}
}
return true;
}

protected:
// Update and get the cached tensor from the cached_tensors_ map.
// The tensor is indexed by a string.
// If the tensor doesn't exists in the map, then create a new tensor.
// If the tensor exists and shape is getting larger, then realloc the buffer.
// If the tensor exists and shape is not getting larger, then return the
// cached tensor directly.
FDTensor* UpdateAndGetCachedTensor(
const std::vector<int64_t>& new_shape, const FDDataType& data_type,
const std::string& tensor_name, const Device& new_device = Device::CPU,
const bool& use_pinned_memory = false);
virtual bool operator()(FDMat* mat, ProcLib lib = ProcLib::DEFAULT);

// Create an input tensor on GPU and save into cached_tensors_.
// If the Mat is on GPU, return the mat->Tensor() directly.
// If the Mat is on CPU, then create a cached GPU tensor and copy the mat's
// CPU tensor to this new GPU tensor.
FDTensor* CreateCachedGpuInputTensor(Mat* mat,
const std::string& tensor_name);

private:
std::unordered_map<std::string, FDTensor> cached_tensors_;
virtual bool operator()(FDMatBatch* mat_batch,
                        ProcLib lib = ProcLib::DEFAULT);
};

} // namespace vision
@@ -23,7 +23,7 @@
namespace fastdeploy {
namespace vision {

bool CenterCrop::ImplByOpenCV(Mat* mat) {
bool CenterCrop::ImplByOpenCV(FDMat* mat) {
cv::Mat* im = mat->GetOpenCVMat();
int height = static_cast<int>(im->rows);
int width = static_cast<int>(im->cols);

@@ -42,7 +42,7 @@ bool CenterCrop::ImplByOpenCV(Mat* mat) {
}

#ifdef ENABLE_FLYCV
bool CenterCrop::ImplByFlyCV(Mat* mat) {
bool CenterCrop::ImplByFlyCV(FDMat* mat) {
fcv::Mat* im = mat->GetFlyCVMat();
int height = static_cast<int>(im->height());
int width = static_cast<int>(im->width());

@@ -63,18 +63,15 @@ bool CenterCrop::ImplByFlyCV(Mat* mat) {
#endif

#ifdef ENABLE_CVCUDA
bool CenterCrop::ImplByCvCuda(Mat* mat) {
bool CenterCrop::ImplByCvCuda(FDMat* mat) {
// Prepare input tensor
std::string tensor_name = Name() + "_cvcuda_src";
FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name);
FDTensor* src = CreateCachedGpuInputTensor(mat);
auto src_tensor = CreateCvCudaTensorWrapData(*src);

// Prepare output tensor
tensor_name = Name() + "_cvcuda_dst";
FDTensor* dst =
UpdateAndGetCachedTensor({height_, width_, mat->Channels()}, src->Dtype(),
tensor_name, Device::GPU);
auto dst_tensor = CreateCvCudaTensorWrapData(*dst);
mat->output_cache->Resize({height_, width_, mat->Channels()}, src->Dtype(),
"output_cache", Device::GPU);
auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache));

int offset_x = static_cast<int>((mat->Width() - width_) / 2);
int offset_y = static_cast<int>((mat->Height() - height_) / 2);

@@ -82,16 +79,27 @@ bool CenterCrop::ImplByCvCuda(Mat* mat) {
NVCVRectI crop_roi = {offset_x, offset_y, width_, height_};
crop_op(mat->Stream(), src_tensor, dst_tensor, crop_roi);

mat->SetTensor(dst);
mat->SetTensor(mat->output_cache);
mat->SetWidth(width_);
mat->SetHeight(height_);
mat->device = Device::GPU;
mat->mat_type = ProcLib::CVCUDA;
return true;
}

bool CenterCrop::ImplByCvCuda(FDMatBatch* mat_batch) {
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
if (ImplByCvCuda(&((*(mat_batch->mats))[i])) != true) {
return false;
}
}
mat_batch->device = Device::GPU;
mat_batch->mat_type = ProcLib::CVCUDA;
return true;
}
#endif

bool CenterCrop::Run(Mat* mat, const int& width, const int& height,
bool CenterCrop::Run(FDMat* mat, const int& width, const int& height,
                     ProcLib lib) {
auto c = CenterCrop(width, height);
return c(mat, lib);
@@ -22,16 +22,17 @@ namespace vision {
class FASTDEPLOY_DECL CenterCrop : public Processor {
public:
CenterCrop(int width, int height) : height_(height), width_(width) {}
bool ImplByOpenCV(Mat* mat);
bool ImplByOpenCV(FDMat* mat);
#ifdef ENABLE_FLYCV
bool ImplByFlyCV(Mat* mat);
bool ImplByFlyCV(FDMat* mat);
#endif
#ifdef ENABLE_CVCUDA
bool ImplByCvCuda(Mat* mat);
bool ImplByCvCuda(FDMat* mat);
bool ImplByCvCuda(FDMatBatch* mat_batch);
#endif
std::string Name() { return "CenterCrop"; }

static bool Run(Mat* mat, const int& width, const int& height,
static bool Run(FDMat* mat, const int& width, const int& height,
                ProcLib lib = ProcLib::DEFAULT);

private:
@@ -47,17 +47,19 @@ nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor) {
"When create CVCUDA tensor from FD tensor,"
"tensor shape should be 3-Dim, HWC layout");
int batchsize = 1;
int h = tensor.Shape()[0];
int w = tensor.Shape()[1];
int c = tensor.Shape()[2];

nvcv::TensorDataStridedCuda::Buffer buf;
buf.strides[3] = FDDataTypeSize(tensor.Dtype());
buf.strides[2] = tensor.shape[2] * buf.strides[3];
buf.strides[1] = tensor.shape[1] * buf.strides[2];
buf.strides[0] = tensor.shape[0] * buf.strides[1];
buf.strides[2] = c * buf.strides[3];
buf.strides[1] = w * buf.strides[2];
buf.strides[0] = h * buf.strides[1];
buf.basePtr = reinterpret_cast<NVCVByte*>(const_cast<void*>(tensor.Data()));

nvcv::Tensor::Requirements req = nvcv::Tensor::CalcRequirements(
    batchsize, {tensor.shape[1], tensor.shape[0]},
    CreateCvCudaImageFormat(tensor.Dtype(), tensor.shape[2]));
    batchsize, {w, h}, CreateCvCudaImageFormat(tensor.Dtype(), c));

nvcv::TensorDataStridedCuda tensor_data(
    nvcv::TensorShape{req.shape, req.rank, req.layout},

@@ -70,6 +72,33 @@ void* GetCvCudaTensorDataPtr(const nvcv::TensorWrapData& tensor) {
    dynamic_cast<const nvcv::ITensorDataStridedCuda*>(tensor.exportData());
return reinterpret_cast<void*>(data->basePtr());
}

nvcv::ImageWrapData CreateImageWrapData(const FDTensor& tensor) {
FDASSERT(tensor.shape.size() == 3,
"When create CVCUDA image from FD tensor,"
"tensor shape should be 3-Dim, HWC layout");
int h = tensor.Shape()[0];
int w = tensor.Shape()[1];
int c = tensor.Shape()[2];
nvcv::ImageDataStridedCuda::Buffer buf;
buf.numPlanes = 1;
buf.planes[0].width = w;
buf.planes[0].height = h;
buf.planes[0].rowStride = w * c * FDDataTypeSize(tensor.Dtype());
buf.planes[0].basePtr =
    reinterpret_cast<NVCVByte*>(const_cast<void*>(tensor.Data()));
nvcv::ImageWrapData nvimg{nvcv::ImageDataStridedCuda{
    nvcv::ImageFormat{CreateCvCudaImageFormat(tensor.Dtype(), c)}, buf}};
return nvimg;
}

void CreateCvCudaImageBatchVarShape(std::vector<FDTensor*>& tensors,
                                    nvcv::ImageBatchVarShape& img_batch) {
for (size_t i = 0; i < tensors.size(); ++i) {
FDASSERT(tensors[i]->device == Device::GPU, "Tensor must on GPU.");
img_batch.pushBack(CreateImageWrapData(*(tensors[i])));
}
}
#endif

} // namespace vision
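For orientation, the stride arithmetic above packs a single HWC image as an NHWC tensor with batch size 1. A NumPy sketch (not part of the commit) of the same byte strides:

# Sketch: the byte strides CreateCvCudaTensorWrapData computes for an HWC
# tensor, cross-checked against NumPy for a 224x224x3 uint8 image.
import numpy as np

h, w, c = 224, 224, 3
elem = np.dtype(np.uint8).itemsize      # FDDataTypeSize(tensor.Dtype())
strides = [0, 0, 0, elem]               # [N, H, W, C] strides in bytes
strides[2] = c * strides[3]             # step between neighboring columns
strides[1] = w * strides[2]             # step between neighboring rows
strides[0] = h * strides[1]             # step between images in the batch
assert tuple(strides[1:]) == np.zeros((h, w, c), np.uint8).strides  # (672, 3, 1)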
@@ -18,6 +18,7 @@

#ifdef ENABLE_CVCUDA
#include "nvcv/Tensor.hpp"
#include <nvcv/ImageBatch.hpp>

namespace fastdeploy {
namespace vision {

@@ -25,7 +26,10 @@ namespace vision {
nvcv::ImageFormat CreateCvCudaImageFormat(FDDataType type, int channel);
nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor);
void* GetCvCudaTensorDataPtr(const nvcv::TensorWrapData& tensor);
nvcv::ImageWrapData CreateImageWrapData(const FDTensor& tensor);
void CreateCvCudaImageBatchVarShape(std::vector<FDTensor*>& tensors,
                                    nvcv::ImageBatchVarShape& img_batch);

}
}
} // namespace vision
} // namespace fastdeploy
#endif
@@ -62,13 +62,24 @@ bool ProcessorManager::Run(std::vector<FDMat>* images,
return false;
}

for (size_t i = 0; i < images->size(); ++i) {
if (CudaUsed()) {
SetStream(&((*images)[i]));
}
if (images->size() > input_caches_.size()) {
input_caches_.resize(images->size());
output_caches_.resize(images->size());
}

bool ret = Apply(images, outputs);
FDMatBatch image_batch(images);
image_batch.input_cache = &batch_input_cache_;
image_batch.output_cache = &batch_output_cache_;

for (size_t i = 0; i < images->size(); ++i) {
if (CudaUsed()) {
SetStream(&image_batch);
}
(*images)[i].input_cache = &input_caches_[i];
(*images)[i].output_cache = &output_caches_[i];
}

bool ret = Apply(&image_batch, outputs);

if (CudaUsed()) {
SyncStream();
@@ -16,6 +16,7 @@

#include "fastdeploy/utils/utils.h"
#include "fastdeploy/vision/common/processors/mat.h"
#include "fastdeploy/vision/common/processors/mat_batch.h"

namespace fastdeploy {
namespace vision {

@@ -24,16 +25,28 @@ class FASTDEPLOY_DECL ProcessorManager {
public:
~ProcessorManager();

/** \brief Use CUDA to boost the performance of processors
 *
 * \param[in] enable_cv_cuda ture: use CV-CUDA, false: use CUDA only
 * \param[in] gpu_id GPU device id
 * \return true if the preprocess successed, otherwise false
 */
void UseCuda(bool enable_cv_cuda = false, int gpu_id = -1);

bool CudaUsed();

void SetStream(Mat* mat) {
void SetStream(FDMat* mat) {
#ifdef WITH_GPU
mat->SetStream(stream_);
#endif
}

void SetStream(FDMatBatch* mat_batch) {
#ifdef WITH_GPU
mat_batch->SetStream(stream_);
#endif
}

void SyncStream() {
#ifdef WITH_GPU
FDASSERT(cudaStreamSynchronize(stream_) == cudaSuccess,

@@ -51,13 +64,13 @@ class FASTDEPLOY_DECL ProcessorManager {
 */
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);

/** \brief The body of Run() function which needs to be implemented by a derived class
/** \brief Apply() is the body of Run() function, it needs to be implemented by a derived class
 *
 * \param[in] images The input image data list, all the elements are returned by cv::imread()
 * \param[in] image_batch The input image batch
 * \param[in] outputs The output tensors which will feed in runtime
 * \return true if the preprocess successed, otherwise false
 */
virtual bool Apply(std::vector<FDMat>* images,
virtual bool Apply(FDMatBatch* image_batch,
                   std::vector<FDTensor>* outputs) = 0;

protected:

@@ -68,6 +81,11 @@ class FASTDEPLOY_DECL ProcessorManager {
cudaStream_t stream_ = nullptr;
#endif
int device_id_ = -1;

std::vector<FDTensor> input_caches_;
std::vector<FDTensor> output_caches_;
FDTensor batch_input_cache_;
FDTensor batch_output_cache_;
};

} // namespace vision
fastdeploy/vision/common/processors/manager_pybind.cc (new file, 41 lines)

@@ -0,0 +1,41 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"

namespace fastdeploy {
void BindProcessorManager(pybind11::module& m) {
pybind11::class_<vision::ProcessorManager>(m, "ProcessorManager")
    .def("run",
         [](vision::ProcessorManager& self,
            std::vector<pybind11::array>& im_list) {
           std::vector<vision::FDMat> images;
           for (size_t i = 0; i < im_list.size(); ++i) {
             images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i])));
           }
           std::vector<FDTensor> outputs;
           if (!self.Run(&images, &outputs)) {
             throw std::runtime_error("Failed to process the input data");
           }
           if (!self.CudaUsed()) {
             for (size_t i = 0; i < outputs.size(); ++i) {
               outputs[i].StopSharing();
             }
           }
           return outputs;
         })
    .def("use_cuda",
         [](vision::ProcessorManager& self, bool enable_cv_cuda = false,
            int gpu_id = -1) { self.UseCuda(enable_cv_cuda, gpu_id); });
}
} // namespace fastdeploy
@@ -247,5 +247,40 @@ std::vector<FDMat> WrapMat(const std::vector<cv::Mat>& images) {
return mats;
}

bool CheckShapeConsistency(std::vector<Mat>* mats) {
for (size_t i = 1; i < mats->size(); ++i) {
if ((*mats)[i].Channels() != (*mats)[0].Channels() ||
    (*mats)[i].Width() != (*mats)[0].Width() ||
    (*mats)[i].Height() != (*mats)[0].Height()) {
return false;
}
}
return true;
}

FDTensor* CreateCachedGpuInputTensor(Mat* mat) {
#ifdef WITH_GPU
FDTensor* src = mat->Tensor();
if (src->device == Device::GPU) {
return src;
} else if (src->device == Device::CPU) {
// Mats on CPU, we need copy these tensors from CPU to GPU
FDASSERT(src->Shape().size() == 3, "The CPU tensor must has 3 dims.")
mat->input_cache->Resize(src->Shape(), src->Dtype(), "input_cache",
                         Device::GPU);
FDASSERT(
    cudaMemcpyAsync(mat->input_cache->Data(), src->Data(), src->Nbytes(),
                    cudaMemcpyHostToDevice, mat->Stream()) == 0,
    "[ERROR] Error occurs while copy memory from CPU to GPU.");
return mat->input_cache;
} else {
FDASSERT(false, "FDMat is on unsupported device: %d", src->device);
}
#else
FDASSERT(false, "FastDeploy didn't compile with WITH_GPU.");
#endif
return nullptr;
}

} // namespace vision
} // namespace fastdeploy
@@ -119,6 +119,11 @@ struct FASTDEPLOY_DECL Mat {
void SetChannels(int s) { channels = s; }
void SetWidth(int w) { width = w; }
void SetHeight(int h) { height = h; }

// When using CV-CUDA/CUDA, please set input/output cache,
// refer to manager.cc
FDTensor* input_cache = nullptr;
FDTensor* output_cache = nullptr;
#ifdef WITH_GPU
cudaStream_t Stream() const { return stream; }
void SetStream(cudaStream_t s) { stream = s; }

@@ -165,5 +170,12 @@ FASTDEPLOY_DECL FDMat WrapMat(const cv::Mat& image);
 */
FASTDEPLOY_DECL std::vector<FDMat> WrapMat(const std::vector<cv::Mat>& images);

bool CheckShapeConsistency(std::vector<Mat>* mats);

// Create an input tensor on GPU and save into input_cache.
// If the Mat is on GPU, return the mat->Tensor() directly.
// If the Mat is on CPU, then update the input cache tensor and copy the mat's
// CPU tensor to this new GPU input cache tensor.
FDTensor* CreateCachedGpuInputTensor(Mat* mat);
} // namespace vision
} // namespace fastdeploy
fastdeploy/vision/common/processors/mat_batch.cc (new file, 81 lines)

@@ -0,0 +1,81 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/common/processors/mat_batch.h"

namespace fastdeploy {
namespace vision {

#ifdef WITH_GPU
void FDMatBatch::SetStream(cudaStream_t s) {
stream = s;
for (size_t i = 0; i < mats->size(); ++i) {
(*mats)[i].SetStream(s);
}
}
#endif

FDTensor* FDMatBatch::Tensor() {
if (has_batched_tensor) {
return &fd_tensor;
}
FDASSERT(CheckShapeConsistency(mats), "Mats shapes are not consistent.")
// Each mat has its own tensor,
// to get a batched tensor, we need copy these tensors to a batched tensor
FDTensor* src = (*mats)[0].Tensor();
auto new_shape = src->Shape();
new_shape.insert(new_shape.begin(), mats->size());
input_cache->Resize(new_shape, src->Dtype(), "batch_input_cache", device);
for (size_t i = 0; i < mats->size(); ++i) {
FDASSERT(device == (*mats)[i].Tensor()->device,
"Mats and MatBatch are not on the same device");
uint8_t* p = reinterpret_cast<uint8_t*>(input_cache->Data());
int num_bytes = (*mats)[i].Tensor()->Nbytes();
FDTensor::CopyBuffer(p + i * num_bytes, (*mats)[i].Tensor()->Data(),
                     num_bytes, device, false);
}
SetTensor(input_cache);
return &fd_tensor;
}

void FDMatBatch::SetTensor(FDTensor* tensor) {
fd_tensor.SetExternalData(tensor->Shape(), tensor->Dtype(), tensor->Data(),
                          tensor->device, tensor->device_id);
has_batched_tensor = true;
}

FDTensor* CreateCachedGpuInputTensor(FDMatBatch* mat_batch) {
#ifdef WITH_GPU
auto mats = mat_batch->mats;
FDASSERT(CheckShapeConsistency(mats), "Mats shapes are not consistent.")
FDTensor* src = (*mats)[0].Tensor();
if (mat_batch->device == Device::GPU) {
return mat_batch->Tensor();
} else if (mat_batch->device == Device::CPU) {
// Mats on CPU, we need copy them to GPU and then get a batched GPU tensor
for (size_t i = 0; i < mats->size(); ++i) {
FDTensor* tensor = CreateCachedGpuInputTensor(&(*mats)[i]);
(*mats)[i].SetTensor(tensor);
}
return mat_batch->Tensor();
} else {
FDASSERT(false, "FDMat is on unsupported device: %d", src->device);
}
#else
FDASSERT(false, "FastDeploy didn't compile with WITH_GPU.");
#endif
return nullptr;
}

} // namespace vision
} // namespace fastdeploy
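FDMatBatch::Tensor() above is a plain byte-wise concatenation of N equally-shaped HWC tensors into one N x H x W x C buffer. A NumPy sketch (not from the commit) of the same operation:

# Sketch: what the per-mat CopyBuffer loop in FDMatBatch::Tensor() amounts to.
import numpy as np

mats = [np.random.randint(0, 256, (224, 224, 3), np.uint8) for _ in range(4)]
batched = np.stack(mats, axis=0)                # shape (4, 224, 224, 3)

cache = np.empty(batched.nbytes, np.uint8)      # the "batch_input_cache"
for i, m in enumerate(mats):
    # copy each mat's Nbytes() at offset i * num_bytes into the batch cache
    cache[i * m.nbytes:(i + 1) * m.nbytes] = m.reshape(-1)
assert np.array_equal(cache.reshape(batched.shape), batched)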
fastdeploy/vision/common/processors/mat_batch.h (new file, 76 lines)

@@ -0,0 +1,76 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/mat.h"

#ifdef WITH_GPU
#include <cuda_runtime_api.h>
#endif

namespace fastdeploy {
namespace vision {

enum FDMatBatchLayout { NHWC, NCHW };

struct FASTDEPLOY_DECL FDMatBatch {
FDMatBatch() = default;

// MatBatch is intialized with a list of mats,
// the data is stored in the mats separately.
// Call Tensor() function to get a batched 4-dimension tensor.
explicit FDMatBatch(std::vector<Mat>* _mats) {
mats = _mats;
layout = FDMatBatchLayout::NHWC;
mat_type = ProcLib::OPENCV;
}

// Get the batched 4-dimension tensor.
FDTensor* Tensor();

void SetTensor(FDTensor* tensor);

private:
#ifdef WITH_GPU
cudaStream_t stream = nullptr;
#endif
FDTensor fd_tensor;

public:
// When using CV-CUDA/CUDA, please set input/output cache,
// refer to manager.cc
FDTensor* input_cache;
FDTensor* output_cache;
#ifdef WITH_GPU
cudaStream_t Stream() const { return stream; }
void SetStream(cudaStream_t s);
#endif

std::vector<FDMat>* mats;
ProcLib mat_type = ProcLib::OPENCV;
FDMatBatchLayout layout = FDMatBatchLayout::NHWC;
Device device = Device::CPU;

// False: the data is stored in the mats separately
// True: the data is stored in the fd_tensor continuously in 4 dimensions
bool has_batched_tensor = false;
};

// Create a batched input tensor on GPU and save into input_cache.
// If the MatBatch is on GPU, return the Tensor() directly.
// If the MatBatch is on CPU, then copy the CPU tensors to GPU and get a GPU
// batched input tensor.
FDTensor* CreateCachedGpuInputTensor(FDMatBatch* mat_batch);

} // namespace vision
} // namespace fastdeploy
@@ -56,7 +56,7 @@ NormalizeAndPermute::NormalizeAndPermute(const std::vector<float>& mean,
swap_rb_ = swap_rb;
}

bool NormalizeAndPermute::ImplByOpenCV(Mat* mat) {
bool NormalizeAndPermute::ImplByOpenCV(FDMat* mat) {
cv::Mat* im = mat->GetOpenCVMat();
int origin_w = im->cols;
int origin_h = im->rows;

@@ -79,7 +79,7 @@ bool NormalizeAndPermute::ImplByOpenCV(Mat* mat) {
}

#ifdef ENABLE_FLYCV
bool NormalizeAndPermute::ImplByFlyCV(Mat* mat) {
bool NormalizeAndPermute::ImplByFlyCV(FDMat* mat) {
if (mat->layout != Layout::HWC) {
FDERROR << "Only supports input with HWC layout." << std::endl;
return false;

@@ -109,7 +109,7 @@ bool NormalizeAndPermute::ImplByFlyCV(Mat* mat) {
}
#endif

bool NormalizeAndPermute::Run(Mat* mat, const std::vector<float>& mean,
bool NormalizeAndPermute::Run(FDMat* mat, const std::vector<float>& mean,
                              const std::vector<float>& std, bool is_scale,
                              const std::vector<float>& min,
                              const std::vector<float>& max, ProcLib lib,
@@ -18,63 +18,110 @@
namespace fastdeploy {
namespace vision {

__global__ void NormalizeAndPermuteKernel(uint8_t* src, float* dst,
__global__ void NormalizeAndPermuteKernel(const uint8_t* src, float* dst,
                                          const float* alpha, const float* beta,
                                          int num_channel, bool swap_rb,
                                          int edge) {
                                          int batch_size, int edge) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
if (idx >= edge) return;

if (swap_rb) {
uint8_t tmp = src[num_channel * idx];
src[num_channel * idx] = src[num_channel * idx + 2];
src[num_channel * idx + 2] = tmp;
}
int img_size = edge / batch_size;
int n = idx / img_size;        // batch index
int p = idx - (n * img_size);  // pixel index within the image

for (int i = 0; i < num_channel; ++i) {
dst[idx + edge * i] = src[num_channel * idx + i] * alpha[i] + beta[i];
int j = i;
if (swap_rb) {
j = 2 - i;
}
dst[n * img_size * num_channel + i * img_size + p] =
    src[num_channel * idx + j] * alpha[i] + beta[i];
}
}
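The rewritten kernel covers a whole NHWC batch with one launch: thread idx picks image n and pixel p, writes to the NCHW offset n*img_size*num_channel + i*img_size + p, and handles swap_rb by reading channel 2-i instead of mutating src in place. A NumPy model of that index math (a sketch, not part of the commit):

# Sketch: NumPy model of the batched NormalizeAndPermuteKernel index math.
import numpy as np

def normalize_and_permute(src, alpha, beta, swap_rb):
    n, h, w, c = src.shape                   # NHWC uint8 batch
    img_size, edge = h * w, n * h * w
    flat = src.reshape(-1)
    dst = np.empty(edge * c, np.float32)
    for idx in range(edge):                  # one CUDA thread per pixel
        b, p = divmod(idx, img_size)         # image index, pixel index
        for i in range(c):
            j = 2 - i if swap_rb else i      # read the swapped channel
            dst[b * img_size * c + i * img_size + p] = (
                flat[c * idx + j] * alpha[i] + beta[i])
    return dst.reshape(n, c, h, w)           # NCHW output

imgs = np.random.randint(0, 256, (2, 4, 5, 3), np.uint8)
alpha = np.array([1 / 255.0] * 3, np.float32)
beta = np.array([0.0] * 3, np.float32)
ref = (imgs[..., ::-1] * alpha + beta).transpose(0, 3, 1, 2)
assert np.allclose(normalize_and_permute(imgs, alpha, beta, True), ref)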
bool NormalizeAndPermute::ImplByCuda(Mat* mat) {
bool NormalizeAndPermute::ImplByCuda(FDMat* mat) {
// Prepare input tensor
std::string tensor_name = Name() + "_cvcuda_src";
FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name);
FDTensor* src = CreateCachedGpuInputTensor(mat);

// Prepare output tensor
tensor_name = Name() + "_dst";
FDTensor* dst = UpdateAndGetCachedTensor(src->Shape(), FDDataType::FP32,
                                         tensor_name, Device::GPU);
mat->output_cache->Resize(src->Shape(), FDDataType::FP32, "output_cache",
                          Device::GPU);

// Copy alpha and beta to GPU
tensor_name = Name() + "_alpha";
FDMat alpha_mat =
    FDMat::Create(1, 1, alpha_.size(), FDDataType::FP32, alpha_.data());
FDTensor* alpha = CreateCachedGpuInputTensor(&alpha_mat, tensor_name);
gpu_alpha_.Resize({1, 1, static_cast<int>(alpha_.size())}, FDDataType::FP32,
                  "alpha", Device::GPU);
cudaMemcpy(gpu_alpha_.Data(), alpha_.data(), gpu_alpha_.Nbytes(),
           cudaMemcpyHostToDevice);

tensor_name = Name() + "_beta";
FDMat beta_mat =
    FDMat::Create(1, 1, beta_.size(), FDDataType::FP32, beta_.data());
FDTensor* beta = CreateCachedGpuInputTensor(&beta_mat, tensor_name);
gpu_beta_.Resize({1, 1, static_cast<int>(beta_.size())}, FDDataType::FP32,
                 "beta", Device::GPU);
cudaMemcpy(gpu_beta_.Data(), beta_.data(), gpu_beta_.Nbytes(),
           cudaMemcpyHostToDevice);

int jobs = mat->Width() * mat->Height();
int jobs = 1 * mat->Width() * mat->Height();
int threads = 256;
int blocks = ceil(jobs / (float)threads);
NormalizeAndPermuteKernel<<<blocks, threads, 0, mat->Stream()>>>(
    reinterpret_cast<uint8_t*>(src->Data()),
    reinterpret_cast<float*>(dst->Data()),
    reinterpret_cast<float*>(alpha->Data()),
    reinterpret_cast<float*>(beta->Data()), mat->Channels(), swap_rb_, jobs);
    reinterpret_cast<float*>(mat->output_cache->Data()),
    reinterpret_cast<float*>(gpu_alpha_.Data()),
    reinterpret_cast<float*>(gpu_beta_.Data()), mat->Channels(), swap_rb_, 1,
    jobs);

mat->SetTensor(dst);
mat->SetTensor(mat->output_cache);
mat->device = Device::GPU;
mat->layout = Layout::CHW;
mat->mat_type = ProcLib::CUDA;
return true;
}

bool NormalizeAndPermute::ImplByCuda(FDMatBatch* mat_batch) {
// Prepare input tensor
FDTensor* src = CreateCachedGpuInputTensor(mat_batch);

// Prepare output tensor
mat_batch->output_cache->Resize(src->Shape(), FDDataType::FP32,
                                "output_cache", Device::GPU);
// NHWC -> NCHW
std::swap(mat_batch->output_cache->shape[1],
          mat_batch->output_cache->shape[3]);

// Copy alpha and beta to GPU
gpu_alpha_.Resize({1, 1, static_cast<int>(alpha_.size())}, FDDataType::FP32,
                  "alpha", Device::GPU);
cudaMemcpy(gpu_alpha_.Data(), alpha_.data(), gpu_alpha_.Nbytes(),
           cudaMemcpyHostToDevice);

gpu_beta_.Resize({1, 1, static_cast<int>(beta_.size())}, FDDataType::FP32,
                 "beta", Device::GPU);
cudaMemcpy(gpu_beta_.Data(), beta_.data(), gpu_beta_.Nbytes(),
           cudaMemcpyHostToDevice);

int jobs =
    mat_batch->output_cache->Numel() / mat_batch->output_cache->shape[1];
int threads = 256;
int blocks = ceil(jobs / (float)threads);
NormalizeAndPermuteKernel<<<blocks, threads, 0, mat_batch->Stream()>>>(
    reinterpret_cast<uint8_t*>(src->Data()),
    reinterpret_cast<float*>(mat_batch->output_cache->Data()),
    reinterpret_cast<float*>(gpu_alpha_.Data()),
    reinterpret_cast<float*>(gpu_beta_.Data()),
    mat_batch->output_cache->shape[1], swap_rb_,
    mat_batch->output_cache->shape[0], jobs);

mat_batch->SetTensor(mat_batch->output_cache);
mat_batch->device = Device::GPU;
mat_batch->layout = FDMatBatchLayout::NCHW;
mat_batch->mat_type = ProcLib::CUDA;
return true;
}

#ifdef ENABLE_CVCUDA
bool NormalizeAndPermute::ImplByCvCuda(Mat* mat) { return ImplByCuda(mat); }
bool NormalizeAndPermute::ImplByCvCuda(FDMat* mat) { return ImplByCuda(mat); }

bool NormalizeAndPermute::ImplByCvCuda(FDMatBatch* mat_batch) {
return ImplByCuda(mat_batch);
}
#endif

} // namespace vision
@@ -25,15 +25,17 @@ class FASTDEPLOY_DECL NormalizeAndPermute : public Processor {
                    const std::vector<float>& min = std::vector<float>(),
                    const std::vector<float>& max = std::vector<float>(),
                    bool swap_rb = false);
bool ImplByOpenCV(Mat* mat);
bool ImplByOpenCV(FDMat* mat);
#ifdef ENABLE_FLYCV
bool ImplByFlyCV(Mat* mat);
bool ImplByFlyCV(FDMat* mat);
#endif
#ifdef WITH_GPU
bool ImplByCuda(Mat* mat);
bool ImplByCuda(FDMat* mat);
bool ImplByCuda(FDMatBatch* mat_batch);
#endif
#ifdef ENABLE_CVCUDA
bool ImplByCvCuda(Mat* mat);
bool ImplByCvCuda(FDMat* mat);
bool ImplByCvCuda(FDMatBatch* mat_batch);
#endif
std::string Name() { return "NormalizeAndPermute"; }

@@ -47,7 +49,7 @@ class FASTDEPLOY_DECL NormalizeAndPermute : public Processor {
// There will be some precomputation in contruct function
// and the `norm(mat)` only need to compute result = mat * alpha + beta
// which will reduce lots of time
static bool Run(Mat* mat, const std::vector<float>& mean,
static bool Run(FDMat* mat, const std::vector<float>& mean,
                const std::vector<float>& std, bool is_scale = true,
                const std::vector<float>& min = std::vector<float>(),
                const std::vector<float>& max = std::vector<float>(),

@@ -76,6 +78,8 @@ class FASTDEPLOY_DECL NormalizeAndPermute : public Processor {
private:
std::vector<float> alpha_;
std::vector<float> beta_;
FDTensor gpu_alpha_;
FDTensor gpu_beta_;
bool swap_rb_;
};
} // namespace vision
@@ -23,7 +23,7 @@
namespace fastdeploy {
namespace vision {

bool Resize::ImplByOpenCV(Mat* mat) {
bool Resize::ImplByOpenCV(FDMat* mat) {
if (mat->layout != Layout::HWC) {
FDERROR << "Resize: The format of input is not HWC." << std::endl;
return false;

@@ -61,7 +61,7 @@ bool Resize::ImplByOpenCV(Mat* mat) {
}

#ifdef ENABLE_FLYCV
bool Resize::ImplByFlyCV(Mat* mat) {
bool Resize::ImplByFlyCV(FDMat* mat) {
if (mat->layout != Layout::HWC) {
FDERROR << "Resize: The format of input is not HWC." << std::endl;
return false;

@@ -123,7 +123,7 @@ bool Resize::ImplByFlyCV(Mat* mat) {
#endif

#ifdef ENABLE_CVCUDA
bool Resize::ImplByCvCuda(Mat* mat) {
bool Resize::ImplByCvCuda(FDMat* mat) {
if (width_ == mat->Width() && height_ == mat->Height()) {
return true;
}

@@ -143,23 +143,20 @@ bool Resize::ImplByCvCuda(Mat* mat) {
}

// Prepare input tensor
std::string tensor_name = Name() + "_cvcuda_src";
FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name);
FDTensor* src = CreateCachedGpuInputTensor(mat);
auto src_tensor = CreateCvCudaTensorWrapData(*src);

// Prepare output tensor
tensor_name = Name() + "_cvcuda_dst";
FDTensor* dst =
    UpdateAndGetCachedTensor({height_, width_, mat->Channels()}, mat->Type(),
                             tensor_name, Device::GPU);
auto dst_tensor = CreateCvCudaTensorWrapData(*dst);
mat->output_cache->Resize({height_, width_, mat->Channels()}, mat->Type(),
                          "output_cache", Device::GPU);
auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache));

// CV-CUDA Interp value is compatible with OpenCV
cvcuda::Resize resize_op;
resize_op(mat->Stream(), src_tensor, dst_tensor,
          NVCVInterpolationType(interp_));

mat->SetTensor(dst);
mat->SetTensor(mat->output_cache);
mat->SetWidth(width_);
mat->SetHeight(height_);
mat->device = Device::GPU;

@@ -168,8 +165,8 @@ bool Resize::ImplByCvCuda(Mat* mat) {
}
#endif

bool Resize::Run(Mat* mat, int width, int height, float scale_w, float scale_h,
                 int interp, bool use_scale, ProcLib lib) {
bool Resize::Run(FDMat* mat, int width, int height, float scale_w,
                 float scale_h, int interp, bool use_scale, ProcLib lib) {
if (mat->Height() == height && mat->Width() == width) {
return true;
}
@@ -31,16 +31,16 @@ class FASTDEPLOY_DECL Resize : public Processor {
use_scale_ = use_scale;
}

bool ImplByOpenCV(Mat* mat);
bool ImplByOpenCV(FDMat* mat);
#ifdef ENABLE_FLYCV
bool ImplByFlyCV(Mat* mat);
bool ImplByFlyCV(FDMat* mat);
#endif
#ifdef ENABLE_CVCUDA
bool ImplByCvCuda(Mat* mat);
bool ImplByCvCuda(FDMat* mat);
#endif
std::string Name() { return "Resize"; }

static bool Run(Mat* mat, int width, int height, float scale_w = -1.0,
static bool Run(FDMat* mat, int width, int height, float scale_w = -1.0,
                float scale_h = -1.0, int interp = 1, bool use_scale = false,
                ProcLib lib = ProcLib::DEFAULT);
@@ -23,7 +23,7 @@
namespace fastdeploy {
namespace vision {

bool ResizeByShort::ImplByOpenCV(Mat* mat) {
bool ResizeByShort::ImplByOpenCV(FDMat* mat) {
cv::Mat* im = mat->GetOpenCVMat();
int origin_w = im->cols;
int origin_h = im->rows;

@@ -43,7 +43,7 @@ bool ResizeByShort::ImplByOpenCV(Mat* mat) {
}

#ifdef ENABLE_FLYCV
bool ResizeByShort::ImplByFlyCV(Mat* mat) {
bool ResizeByShort::ImplByFlyCV(FDMat* mat) {
fcv::Mat* im = mat->GetFlyCVMat();
int origin_w = im->width();
int origin_h = im->height();

@@ -87,10 +87,9 @@ bool ResizeByShort::ImplByFlyCV(Mat* mat) {
#endif

#ifdef ENABLE_CVCUDA
bool ResizeByShort::ImplByCvCuda(Mat* mat) {
bool ResizeByShort::ImplByCvCuda(FDMat* mat) {
// Prepare input tensor
std::string tensor_name = Name() + "_cvcuda_src";
FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name);
FDTensor* src = CreateCachedGpuInputTensor(mat);
auto src_tensor = CreateCvCudaTensorWrapData(*src);

double scale = GenerateScale(mat->Width(), mat->Height());

@@ -98,23 +97,69 @@ bool ResizeByShort::ImplByCvCuda(Mat* mat) {
int height = static_cast<int>(round(scale * mat->Height()));

// Prepare output tensor
tensor_name = Name() + "_cvcuda_dst";
FDTensor* dst = UpdateAndGetCachedTensor(
    {height, width, mat->Channels()}, mat->Type(), tensor_name, Device::GPU);
auto dst_tensor = CreateCvCudaTensorWrapData(*dst);
mat->output_cache->Resize({height, width, mat->Channels()}, mat->Type(),
                          "output_cache", Device::GPU);
auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache));

// CV-CUDA Interp value is compatible with OpenCV
cvcuda::Resize resize_op;
resize_op(mat->Stream(), src_tensor, dst_tensor,
          NVCVInterpolationType(interp_));

mat->SetTensor(dst);
mat->SetTensor(mat->output_cache);
mat->SetWidth(width);
mat->SetHeight(height);
mat->device = Device::GPU;
mat->mat_type = ProcLib::CVCUDA;
return true;
}

bool ResizeByShort::ImplByCvCuda(FDMatBatch* mat_batch) {
// TODO(wangxinyu): to support batched tensor as input
FDASSERT(mat_batch->has_batched_tensor == false,
"ResizeByShort doesn't support batched tensor as input for now.");
// Prepare input batch
std::string tensor_name = Name() + "_cvcuda_src";
std::vector<FDTensor*> src_tensors;
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
FDTensor* src = CreateCachedGpuInputTensor(&(*(mat_batch->mats))[i]);
src_tensors.push_back(src);
}
nvcv::ImageBatchVarShape src_batch(mat_batch->mats->size());
CreateCvCudaImageBatchVarShape(src_tensors, src_batch);

// Prepare output batch
tensor_name = Name() + "_cvcuda_dst";
std::vector<FDTensor*> dst_tensors;
for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
FDMat* mat = &(*(mat_batch->mats))[i];
double scale = GenerateScale(mat->Width(), mat->Height());
int width = static_cast<int>(round(scale * mat->Width()));
int height = static_cast<int>(round(scale * mat->Height()));
mat->output_cache->Resize({height, width, mat->Channels()}, mat->Type(),
                          "output_cache", Device::GPU);
dst_tensors.push_back(mat->output_cache);
}
nvcv::ImageBatchVarShape dst_batch(mat_batch->mats->size());
CreateCvCudaImageBatchVarShape(dst_tensors, dst_batch);

// CV-CUDA Interp value is compatible with OpenCV
cvcuda::Resize resize_op;
resize_op(mat_batch->Stream(), src_batch, dst_batch,
          NVCVInterpolationType(interp_));

for (size_t i = 0; i < mat_batch->mats->size(); ++i) {
FDMat* mat = &(*(mat_batch->mats))[i];
mat->SetTensor(dst_tensors[i]);
mat->SetWidth(dst_tensors[i]->Shape()[1]);
mat->SetHeight(dst_tensors[i]->Shape()[0]);
mat->device = Device::GPU;
mat->mat_type = ProcLib::CVCUDA;
}
mat_batch->device = Device::GPU;
mat_batch->mat_type = ProcLib::CVCUDA;
return true;
}
#endif

double ResizeByShort::GenerateScale(const int origin_w, const int origin_h) {

@@ -143,7 +188,7 @@ double ResizeByShort::GenerateScale(const int origin_w, const int origin_h) {
return scale;
}

bool ResizeByShort::Run(Mat* mat, int target_size, int interp, bool use_scale,
bool ResizeByShort::Run(FDMat* mat, int target_size, int interp, bool use_scale,
                        const std::vector<int>& max_hw, ProcLib lib) {
auto r = ResizeByShort(target_size, interp, use_scale, max_hw);
return r(mat, lib);
@@ -28,16 +28,17 @@ class FASTDEPLOY_DECL ResizeByShort : public Processor {
interp_ = interp;
use_scale_ = use_scale;
}
bool ImplByOpenCV(Mat* mat);
bool ImplByOpenCV(FDMat* mat);
#ifdef ENABLE_FLYCV
bool ImplByFlyCV(Mat* mat);
bool ImplByFlyCV(FDMat* mat);
#endif
#ifdef ENABLE_CVCUDA
bool ImplByCvCuda(Mat* mat);
bool ImplByCvCuda(FDMat* mat);
bool ImplByCvCuda(FDMatBatch* mat_batch);
#endif
std::string Name() { return "ResizeByShort"; }

static bool Run(Mat* mat, int target_size, int interp = 1,
static bool Run(FDMat* mat, int target_size, int interp = 1,
                bool use_scale = true,
                const std::vector<int>& max_hw = std::vector<int>(),
                ProcLib lib = ProcLib::DEFAULT);
@@ -16,6 +16,7 @@

namespace fastdeploy {

void BindProcessorManager(pybind11::module& m);
void BindDetection(pybind11::module& m);
void BindClassification(pybind11::module& m);
void BindSegmentation(pybind11::module& m);

@@ -204,6 +205,7 @@ void BindVision(pybind11::module& m) {
m.def("disable_flycv", &vision::DisableFlyCV,
      "Disable image preprocessing by FlyCV, change to use OpenCV.");

BindProcessorManager(m);
BindDetection(m);
BindClassification(m);
BindSegmentation(m);
@@ -16,44 +16,40 @@ from __future__ import absolute_import
import logging
from .... import FastDeployModel, ModelFormat
from .... import c_lib_wrap as C
from ...common import ProcessorManager


class PaddleClasPreprocessor:
class PaddleClasPreprocessor(ProcessorManager):
    def __init__(self, config_file):
        """Create a preprocessor for PaddleClasModel from configuration file

        :param config_file: (str)Path of configuration file, e.g resnet50/inference_cls.yaml
        """
        self._preprocessor = C.vision.classification.PaddleClasPreprocessor(
        super(PaddleClasPreprocessor, self).__init__()
        self._manager = C.vision.classification.PaddleClasPreprocessor(
            config_file)

    def run(self, input_ims):
        """Preprocess input images for PaddleClasModel

        :param: input_ims: (list of numpy.ndarray)The input image
        :return: list of FDTensor
        """
        return self._preprocessor.run(input_ims)

    def use_cuda(self, enable_cv_cuda=False, gpu_id=-1):
        """Use CUDA preprocessors

        :param: enable_cv_cuda: Whether to enable CV-CUDA
        :param: gpu_id: GPU device id
        """
        return self._preprocessor.use_cuda(enable_cv_cuda, gpu_id)

    def disable_normalize(self):
        """
        This function will disable normalize in preprocessing step.
        """
        self._preprocessor.disable_normalize()
        self._manager.disable_normalize()

    def disable_permute(self):
        """
        This function will disable hwc2chw in preprocessing step.
        """
        self._preprocessor.disable_permute()
        self._manager.disable_permute()

    def initial_resize_on_cpu(self, v):
        """
        When the initial operator is Resize, and input image size is large,
        maybe it's better to run resize on CPU, because the HostToDevice memcpy
        is time consuming. Set this True to run the initial resize on CPU.

        :param: v: True or False
        """
        self._manager.initial_resize_on_cpu(v)


class PaddleClasPostprocessor:
python/fastdeploy/vision/common/__init__.py (new file, 16 lines)

@@ -0,0 +1,16 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import

from .manager import ProcessorManager
python/fastdeploy/vision/common/manager.py (new file, 36 lines)

@@ -0,0 +1,36 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import


class ProcessorManager:
    def __init__(self):
        self._manager = None

    def run(self, input_ims):
        """Process input image

        :param: input_ims: (list of numpy.ndarray) The input images
        :return: list of FDTensor
        """
        return self._manager.run(input_ims)

    def use_cuda(self, enable_cv_cuda=False, gpu_id=-1):
        """Use CUDA processors

        :param: enable_cv_cuda: Ture: use CV-CUDA, False: use CUDA only
        :param: gpu_id: GPU device id
        """
        return self._manager.use_cuda(enable_cv_cuda, gpu_id)
@@ -72,6 +72,7 @@ setup_configs["PADDLELITE_URL"] = os.getenv("PADDLELITE_URL", "OFF")
setup_configs["ENABLE_VISION"] = os.getenv("ENABLE_VISION", "OFF")
setup_configs["ENABLE_ENCRYPTION"] = os.getenv("ENABLE_ENCRYPTION", "OFF")
setup_configs["ENABLE_FLYCV"] = os.getenv("ENABLE_FLYCV", "OFF")
setup_configs["ENABLE_CVCUDA"] = os.getenv("ENABLE_CVCUDA", "OFF")
setup_configs["ENABLE_TEXT"] = os.getenv("ENABLE_TEXT", "OFF")
setup_configs["ENABLE_BENCHMARK"] = os.getenv("ENABLE_BENCHMARK", "OFF")
setup_configs["WITH_GPU"] = os.getenv("WITH_GPU", "OFF")