diff --git a/fastdeploy/core/fd_tensor.cc b/fastdeploy/core/fd_tensor.cc index 533e58fd8..8b111025d 100644 --- a/fastdeploy/core/fd_tensor.cc +++ b/fastdeploy/core/fd_tensor.cc @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. #include "fastdeploy/core/fd_tensor.h" -#include "fastdeploy/core/float16.h" -#include "fastdeploy/utils/utils.h" #include #include + +#include "fastdeploy/core/float16.h" +#include "fastdeploy/utils/utils.h" #ifdef WITH_GPU #include #endif @@ -142,6 +143,9 @@ void FDTensor::Resize(const std::vector& new_shape, const FDDataType& data_type, const std::string& tensor_name, const Device& new_device) { + if (device != new_device) { + FreeFn(); + } external_data_ptr = nullptr; name = tensor_name; device = new_device; @@ -269,9 +273,10 @@ bool FDTensor::ReallocFn(size_t nbytes) { } return buffer_ != nullptr; #else - FDASSERT(false, "The FastDeploy FDTensor allocator didn't compile under " - "-DWITH_GPU=ON," - "so this is an unexpected problem happend."); + FDASSERT(false, + "The FastDeploy FDTensor allocator didn't compile under " + "-DWITH_GPU=ON," + "so this is an unexpected problem happend."); #endif } else { if (is_pinned_memory) { @@ -285,9 +290,10 @@ bool FDTensor::ReallocFn(size_t nbytes) { } return buffer_ != nullptr; #else - FDASSERT(false, "The FastDeploy FDTensor allocator didn't compile under " - "-DWITH_GPU=ON," - "so this is an unexpected problem happend."); + FDASSERT(false, + "The FastDeploy FDTensor allocator didn't compile under " + "-DWITH_GPU=ON," + "so this is an unexpected problem happend."); #endif } buffer_ = realloc(buffer_, nbytes); @@ -296,8 +302,7 @@ bool FDTensor::ReallocFn(size_t nbytes) { } void FDTensor::FreeFn() { - if (external_data_ptr != nullptr) - external_data_ptr = nullptr; + if (external_data_ptr != nullptr) external_data_ptr = nullptr; if (buffer_ != nullptr) { if (device == Device::GPU) { #ifdef WITH_GPU @@ -381,13 +386,16 @@ FDTensor::FDTensor(const Scalar& scalar) { (reinterpret_cast(Data()))[0] = scalar.to(); break; default: - break; + break; } } FDTensor::FDTensor(const FDTensor& other) - : shape(other.shape), name(other.name), dtype(other.dtype), - device(other.device), external_data_ptr(other.external_data_ptr), + : shape(other.shape), + name(other.name), + dtype(other.dtype), + device(other.device), + external_data_ptr(other.external_data_ptr), device_id(other.device_id) { // Copy buffer if (other.buffer_ == nullptr) { @@ -401,9 +409,12 @@ FDTensor::FDTensor(const FDTensor& other) } FDTensor::FDTensor(FDTensor&& other) - : buffer_(other.buffer_), shape(std::move(other.shape)), - name(std::move(other.name)), dtype(other.dtype), - external_data_ptr(other.external_data_ptr), device(other.device), + : buffer_(other.buffer_), + shape(std::move(other.shape)), + name(std::move(other.name)), + dtype(other.dtype), + external_data_ptr(other.external_data_ptr), + device(other.device), device_id(other.device_id) { other.name = ""; // Note(zhoushunjie): Avoid double free. 
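The fd_tensor.cc hunk above makes Resize() call FreeFn() whenever the target device changes, so a tensor moved between CPU and GPU no longer keeps its stale buffer. A minimal sketch of the call pattern this guards — the shape, dtype, and tensor name are arbitrary placeholders, and the GPU resize assumes a -DWITH_GPU=ON build:

```cpp
#include "fastdeploy/core/fd_tensor.h"

int main() {
  fastdeploy::FDTensor t;
  // First allocation lives on the CPU.
  t.Resize({1, 3, 224, 224}, fastdeploy::FDDataType::FP32, "x",
           fastdeploy::Device::CPU);
  // With this patch, the CPU buffer is released before the GPU buffer is
  // allocated; previously the old buffer could be left behind.
  t.Resize({1, 3, 224, 224}, fastdeploy::FDDataType::FP32, "x",
           fastdeploy::Device::GPU);
  return 0;
}
```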
diff --git a/fastdeploy/vision/classification/ppcls/ppcls_pybind.cc b/fastdeploy/vision/classification/ppcls/ppcls_pybind.cc index 514f5ad9d..1d4a24adf 100644 --- a/fastdeploy/vision/classification/ppcls/ppcls_pybind.cc +++ b/fastdeploy/vision/classification/ppcls/ppcls_pybind.cc @@ -15,33 +15,9 @@ namespace fastdeploy { void BindPaddleClas(pybind11::module& m) { - pybind11::class_( - m, "PaddleClasPreprocessor") + pybind11::class_(m, "PaddleClasPreprocessor") .def(pybind11::init()) - .def("run", - [](vision::classification::PaddleClasPreprocessor& self, - std::vector& im_list) { - std::vector images; - for (size_t i = 0; i < im_list.size(); ++i) { - images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i]))); - } - std::vector outputs; - if (!self.Run(&images, &outputs)) { - throw std::runtime_error( - "Failed to preprocess the input data in " - "PaddleClasPreprocessor."); - } - if (!self.CudaUsed()) { - for (size_t i = 0; i < outputs.size(); ++i) { - outputs[i].StopSharing(); - } - } - return outputs; - }) - .def("use_cuda", - [](vision::classification::PaddleClasPreprocessor& self, - bool enable_cv_cuda = false, - int gpu_id = -1) { self.UseCuda(enable_cv_cuda, gpu_id); }) .def("disable_normalize", [](vision::classification::PaddleClasPreprocessor& self) { self.DisableNormalize(); @@ -49,6 +25,10 @@ void BindPaddleClas(pybind11::module& m) { .def("disable_permute", [](vision::classification::PaddleClasPreprocessor& self) { self.DisablePermute(); + }) + .def("initial_resize_on_cpu", + [](vision::classification::PaddleClasPreprocessor& self, bool v) { + self.InitialResizeOnCpu(v); }); pybind11::class_( diff --git a/fastdeploy/vision/classification/ppcls/preprocessor.cc b/fastdeploy/vision/classification/ppcls/preprocessor.cc index 90d40e094..ef0da9ce5 100644 --- a/fastdeploy/vision/classification/ppcls/preprocessor.cc +++ b/fastdeploy/vision/classification/ppcls/preprocessor.cc @@ -100,32 +100,23 @@ void PaddleClasPreprocessor::DisablePermute() { } } -bool PaddleClasPreprocessor::Apply(std::vector* images, +bool PaddleClasPreprocessor::Apply(FDMatBatch* image_batch, std::vector* outputs) { - for (size_t i = 0; i < images->size(); ++i) { - for (size_t j = 0; j < processors_.size(); ++j) { - bool ret = false; - ret = (*(processors_[j].get()))(&((*images)[i])); - if (!ret) { - FDERROR << "Failed to processs image:" << i << " in " - << processors_[j]->Name() << "." << std::endl; - return false; - } + for (size_t j = 0; j < processors_.size(); ++j) { + ProcLib lib = ProcLib::DEFAULT; + if (initial_resize_on_cpu_ && j == 0 && + processors_[j]->Name().find("Resize") == 0) { + lib = ProcLib::OPENCV; + } + if (!(*(processors_[j].get()))(image_batch, lib)) { + FDERROR << "Failed to processs image in " << processors_[j]->Name() << "." 
+ << std::endl; + return false; } } outputs->resize(1); - // Concat all the preprocessed data to a batch tensor - std::vector tensors(images->size()); - for (size_t i = 0; i < images->size(); ++i) { - (*images)[i].ShareWithTensor(&(tensors[i])); - tensors[i].ExpandDim(0); - } - if (tensors.size() == 1) { - (*outputs)[0] = std::move(tensors[0]); - } else { - function::Concat(tensors, &((*outputs)[0]), 0); - } + (*outputs)[0] = std::move(*(image_batch->Tensor())); (*outputs)[0].device_id = DeviceId(); return true; } diff --git a/fastdeploy/vision/classification/ppcls/preprocessor.h b/fastdeploy/vision/classification/ppcls/preprocessor.h index 2f2beaddb..fc347fc3d 100644 --- a/fastdeploy/vision/classification/ppcls/preprocessor.h +++ b/fastdeploy/vision/classification/ppcls/preprocessor.h @@ -33,11 +33,11 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor : public ProcessorManager { /** \brief Process the input image and prepare input tensors for runtime * - * \param[in] images The input image data list, all the elements are returned by cv::imread() + * \param[in] image_batch The input image batch * \param[in] outputs The output tensors which will feed in runtime * \return true if the preprocess successed, otherwise false */ - virtual bool Apply(std::vector* images, + virtual bool Apply(FDMatBatch* image_batch, std::vector* outputs); /// This function will disable normalize in preprocessing step. @@ -45,6 +45,14 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor : public ProcessorManager { /// This function will disable hwc2chw in preprocessing step. void DisablePermute(); + /** \brief When the initial operator is Resize, and input image size is large, + * maybe it's better to run resize on CPU, because the HostToDevice memcpy + * is time consuming. Set this true to run the initial resize on CPU. 
+ * + * \param[in] v ture or false + */ + void InitialResizeOnCpu(bool v) { initial_resize_on_cpu_ = v; } + private: bool BuildPreprocessPipelineFromConfig(); std::vector> processors_; @@ -54,6 +62,7 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor : public ProcessorManager { bool disable_normalize_ = false; // read config file std::string config_file_; + bool initial_resize_on_cpu_ = false; }; } // namespace classification diff --git a/fastdeploy/vision/common/processors/base.cc b/fastdeploy/vision/common/processors/base.cc index a47cfe378..9c4a0177e 100644 --- a/fastdeploy/vision/common/processors/base.cc +++ b/fastdeploy/vision/common/processors/base.cc @@ -20,7 +20,7 @@ namespace fastdeploy { namespace vision { -bool Processor::operator()(Mat* mat, ProcLib lib) { +bool Processor::operator()(FDMat* mat, ProcLib lib) { ProcLib target = lib; if (lib == ProcLib::DEFAULT) { target = DefaultProcLib::default_lib; @@ -52,39 +52,38 @@ bool Processor::operator()(Mat* mat, ProcLib lib) { return ImplByOpenCV(mat); } -FDTensor* Processor::UpdateAndGetCachedTensor( - const std::vector& new_shape, const FDDataType& data_type, - const std::string& tensor_name, const Device& new_device, - const bool& use_pinned_memory) { - if (cached_tensors_.count(tensor_name) == 0) { - cached_tensors_[tensor_name] = FDTensor(); - } - cached_tensors_[tensor_name].is_pinned_memory = use_pinned_memory; - cached_tensors_[tensor_name].Resize(new_shape, data_type, tensor_name, - new_device); - return &cached_tensors_[tensor_name]; -} - -FDTensor* Processor::CreateCachedGpuInputTensor( - Mat* mat, const std::string& tensor_name) { -#ifdef WITH_GPU - FDTensor* src = mat->Tensor(); - if (src->device == Device::GPU) { - return src; - } else if (src->device == Device::CPU) { - FDTensor* tensor = UpdateAndGetCachedTensor(src->Shape(), src->Dtype(), - tensor_name, Device::GPU); - FDASSERT(cudaMemcpyAsync(tensor->Data(), src->Data(), tensor->Nbytes(), - cudaMemcpyHostToDevice, mat->Stream()) == 0, - "[ERROR] Error occurs while copy memory from CPU to GPU."); - return tensor; - } else { - FDASSERT(false, "FDMat is on unsupported device: %d", src->device); +bool Processor::operator()(FDMatBatch* mat_batch, ProcLib lib) { + ProcLib target = lib; + if (lib == ProcLib::DEFAULT) { + target = DefaultProcLib::default_lib; } + if (target == ProcLib::FLYCV) { +#ifdef ENABLE_FLYCV + return ImplByFlyCV(mat_batch); #else - FDASSERT(false, "FastDeploy didn't compile with WITH_GPU."); + FDASSERT(false, "FastDeploy didn't compile with FlyCV."); #endif - return nullptr; + } else if (target == ProcLib::CUDA) { +#ifdef WITH_GPU + FDASSERT( + mat_batch->Stream() != nullptr, + "CUDA processor requires cuda stream, please set stream for mat_batch"); + return ImplByCuda(mat_batch); +#else + FDASSERT(false, "FastDeploy didn't compile with WITH_GPU."); +#endif + } else if (target == ProcLib::CVCUDA) { +#ifdef ENABLE_CVCUDA + FDASSERT(mat_batch->Stream() != nullptr, + "CV-CUDA processor requires cuda stream, please set stream for " + "mat_batch"); + return ImplByCvCuda(mat_batch); +#else + FDASSERT(false, "FastDeploy didn't compile with CV-CUDA."); +#endif + } + // DEFAULT & OPENCV + return ImplByOpenCV(mat_batch); } void EnableFlyCV() { diff --git a/fastdeploy/vision/common/processors/base.h b/fastdeploy/vision/common/processors/base.h index 6fb3a33eb..786e88672 100644 --- a/fastdeploy/vision/common/processors/base.h +++ b/fastdeploy/vision/common/processors/base.h @@ -16,6 +16,7 @@ #include "fastdeploy/utils/utils.h" #include 
"fastdeploy/vision/common/processors/mat.h" +#include "fastdeploy/vision/common/processors/mat_batch.h" #include "opencv2/highgui/highgui.hpp" #include "opencv2/imgproc/imgproc.hpp" #include @@ -46,46 +47,63 @@ class FASTDEPLOY_DECL Processor { virtual std::string Name() = 0; - virtual bool ImplByOpenCV(Mat* mat) { + virtual bool ImplByOpenCV(FDMat* mat) { FDERROR << Name() << " Not Implement Yet." << std::endl; return false; } - virtual bool ImplByFlyCV(Mat* mat) { + virtual bool ImplByOpenCV(FDMatBatch* mat_batch) { + for (size_t i = 0; i < mat_batch->mats->size(); ++i) { + if (ImplByOpenCV(&(*(mat_batch->mats))[i]) != true) { + return false; + } + } + return true; + } + + virtual bool ImplByFlyCV(FDMat* mat) { return ImplByOpenCV(mat); } - virtual bool ImplByCuda(Mat* mat) { + virtual bool ImplByFlyCV(FDMatBatch* mat_batch) { + for (size_t i = 0; i < mat_batch->mats->size(); ++i) { + if (ImplByFlyCV(&(*(mat_batch->mats))[i]) != true) { + return false; + } + } + return true; + } + + virtual bool ImplByCuda(FDMat* mat) { return ImplByOpenCV(mat); } - virtual bool ImplByCvCuda(Mat* mat) { + virtual bool ImplByCuda(FDMatBatch* mat_batch) { + for (size_t i = 0; i < mat_batch->mats->size(); ++i) { + if (ImplByCuda(&(*(mat_batch->mats))[i]) != true) { + return false; + } + } + return true; + } + + virtual bool ImplByCvCuda(FDMat* mat) { return ImplByOpenCV(mat); } - virtual bool operator()(Mat* mat, ProcLib lib = ProcLib::DEFAULT); + virtual bool ImplByCvCuda(FDMatBatch* mat_batch) { + for (size_t i = 0; i < mat_batch->mats->size(); ++i) { + if (ImplByCvCuda(&(*(mat_batch->mats))[i]) != true) { + return false; + } + } + return true; + } - protected: - // Update and get the cached tensor from the cached_tensors_ map. - // The tensor is indexed by a string. - // If the tensor doesn't exists in the map, then create a new tensor. - // If the tensor exists and shape is getting larger, then realloc the buffer. - // If the tensor exists and shape is not getting larger, then return the - // cached tensor directly. - FDTensor* UpdateAndGetCachedTensor( - const std::vector& new_shape, const FDDataType& data_type, - const std::string& tensor_name, const Device& new_device = Device::CPU, - const bool& use_pinned_memory = false); + virtual bool operator()(FDMat* mat, ProcLib lib = ProcLib::DEFAULT); - // Create an input tensor on GPU and save into cached_tensors_. - // If the Mat is on GPU, return the mat->Tensor() directly. - // If the Mat is on CPU, then create a cached GPU tensor and copy the mat's - // CPU tensor to this new GPU tensor. 
- FDTensor* CreateCachedGpuInputTensor(Mat* mat, - const std::string& tensor_name); - - private: - std::unordered_map cached_tensors_; + virtual bool operator()(FDMatBatch* mat_batch, + ProcLib lib = ProcLib::DEFAULT); }; } // namespace vision diff --git a/fastdeploy/vision/common/processors/center_crop.cc b/fastdeploy/vision/common/processors/center_crop.cc index bb0c96947..1857f7a81 100644 --- a/fastdeploy/vision/common/processors/center_crop.cc +++ b/fastdeploy/vision/common/processors/center_crop.cc @@ -23,7 +23,7 @@ namespace fastdeploy { namespace vision { -bool CenterCrop::ImplByOpenCV(Mat* mat) { +bool CenterCrop::ImplByOpenCV(FDMat* mat) { cv::Mat* im = mat->GetOpenCVMat(); int height = static_cast(im->rows); int width = static_cast(im->cols); @@ -42,7 +42,7 @@ bool CenterCrop::ImplByOpenCV(Mat* mat) { } #ifdef ENABLE_FLYCV -bool CenterCrop::ImplByFlyCV(Mat* mat) { +bool CenterCrop::ImplByFlyCV(FDMat* mat) { fcv::Mat* im = mat->GetFlyCVMat(); int height = static_cast(im->height()); int width = static_cast(im->width()); @@ -63,18 +63,15 @@ bool CenterCrop::ImplByFlyCV(Mat* mat) { #endif #ifdef ENABLE_CVCUDA -bool CenterCrop::ImplByCvCuda(Mat* mat) { +bool CenterCrop::ImplByCvCuda(FDMat* mat) { // Prepare input tensor - std::string tensor_name = Name() + "_cvcuda_src"; - FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name); + FDTensor* src = CreateCachedGpuInputTensor(mat); auto src_tensor = CreateCvCudaTensorWrapData(*src); // Prepare output tensor - tensor_name = Name() + "_cvcuda_dst"; - FDTensor* dst = - UpdateAndGetCachedTensor({height_, width_, mat->Channels()}, src->Dtype(), - tensor_name, Device::GPU); - auto dst_tensor = CreateCvCudaTensorWrapData(*dst); + mat->output_cache->Resize({height_, width_, mat->Channels()}, src->Dtype(), + "output_cache", Device::GPU); + auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache)); int offset_x = static_cast((mat->Width() - width_) / 2); int offset_y = static_cast((mat->Height() - height_) / 2); @@ -82,16 +79,27 @@ bool CenterCrop::ImplByCvCuda(Mat* mat) { NVCVRectI crop_roi = {offset_x, offset_y, width_, height_}; crop_op(mat->Stream(), src_tensor, dst_tensor, crop_roi); - mat->SetTensor(dst); + mat->SetTensor(mat->output_cache); mat->SetWidth(width_); mat->SetHeight(height_); mat->device = Device::GPU; mat->mat_type = ProcLib::CVCUDA; return true; } + +bool CenterCrop::ImplByCvCuda(FDMatBatch* mat_batch) { + for (size_t i = 0; i < mat_batch->mats->size(); ++i) { + if (ImplByCvCuda(&((*(mat_batch->mats))[i])) != true) { + return false; + } + } + mat_batch->device = Device::GPU; + mat_batch->mat_type = ProcLib::CVCUDA; + return true; +} #endif -bool CenterCrop::Run(Mat* mat, const int& width, const int& height, +bool CenterCrop::Run(FDMat* mat, const int& width, const int& height, ProcLib lib) { auto c = CenterCrop(width, height); return c(mat, lib); diff --git a/fastdeploy/vision/common/processors/center_crop.h b/fastdeploy/vision/common/processors/center_crop.h index 7455773f6..3ca3a7391 100644 --- a/fastdeploy/vision/common/processors/center_crop.h +++ b/fastdeploy/vision/common/processors/center_crop.h @@ -22,16 +22,17 @@ namespace vision { class FASTDEPLOY_DECL CenterCrop : public Processor { public: CenterCrop(int width, int height) : height_(height), width_(width) {} - bool ImplByOpenCV(Mat* mat); + bool ImplByOpenCV(FDMat* mat); #ifdef ENABLE_FLYCV - bool ImplByFlyCV(Mat* mat); + bool ImplByFlyCV(FDMat* mat); #endif #ifdef ENABLE_CVCUDA - bool ImplByCvCuda(Mat* mat); + bool ImplByCvCuda(FDMat* mat); + bool 
ImplByCvCuda(FDMatBatch* mat_batch); #endif std::string Name() { return "CenterCrop"; } - static bool Run(Mat* mat, const int& width, const int& height, + static bool Run(FDMat* mat, const int& width, const int& height, ProcLib lib = ProcLib::DEFAULT); private: diff --git a/fastdeploy/vision/common/processors/cvcuda_utils.cc b/fastdeploy/vision/common/processors/cvcuda_utils.cc index 482d0dc69..c7d25361b 100644 --- a/fastdeploy/vision/common/processors/cvcuda_utils.cc +++ b/fastdeploy/vision/common/processors/cvcuda_utils.cc @@ -47,17 +47,19 @@ nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor) { "When create CVCUDA tensor from FD tensor," "tensor shape should be 3-Dim, HWC layout"); int batchsize = 1; + int h = tensor.Shape()[0]; + int w = tensor.Shape()[1]; + int c = tensor.Shape()[2]; nvcv::TensorDataStridedCuda::Buffer buf; buf.strides[3] = FDDataTypeSize(tensor.Dtype()); - buf.strides[2] = tensor.shape[2] * buf.strides[3]; - buf.strides[1] = tensor.shape[1] * buf.strides[2]; - buf.strides[0] = tensor.shape[0] * buf.strides[1]; + buf.strides[2] = c * buf.strides[3]; + buf.strides[1] = w * buf.strides[2]; + buf.strides[0] = h * buf.strides[1]; buf.basePtr = reinterpret_cast(const_cast(tensor.Data())); nvcv::Tensor::Requirements req = nvcv::Tensor::CalcRequirements( - batchsize, {tensor.shape[1], tensor.shape[0]}, - CreateCvCudaImageFormat(tensor.Dtype(), tensor.shape[2])); + batchsize, {w, h}, CreateCvCudaImageFormat(tensor.Dtype(), c)); nvcv::TensorDataStridedCuda tensor_data( nvcv::TensorShape{req.shape, req.rank, req.layout}, @@ -70,6 +72,33 @@ void* GetCvCudaTensorDataPtr(const nvcv::TensorWrapData& tensor) { dynamic_cast(tensor.exportData()); return reinterpret_cast(data->basePtr()); } + +nvcv::ImageWrapData CreateImageWrapData(const FDTensor& tensor) { + FDASSERT(tensor.shape.size() == 3, + "When create CVCUDA image from FD tensor," + "tensor shape should be 3-Dim, HWC layout"); + int h = tensor.Shape()[0]; + int w = tensor.Shape()[1]; + int c = tensor.Shape()[2]; + nvcv::ImageDataStridedCuda::Buffer buf; + buf.numPlanes = 1; + buf.planes[0].width = w; + buf.planes[0].height = h; + buf.planes[0].rowStride = w * c * FDDataTypeSize(tensor.Dtype()); + buf.planes[0].basePtr = + reinterpret_cast(const_cast(tensor.Data())); + nvcv::ImageWrapData nvimg{nvcv::ImageDataStridedCuda{ + nvcv::ImageFormat{CreateCvCudaImageFormat(tensor.Dtype(), c)}, buf}}; + return nvimg; +} + +void CreateCvCudaImageBatchVarShape(std::vector& tensors, + nvcv::ImageBatchVarShape& img_batch) { + for (size_t i = 0; i < tensors.size(); ++i) { + FDASSERT(tensors[i]->device == Device::GPU, "Tensor must on GPU."); + img_batch.pushBack(CreateImageWrapData(*(tensors[i]))); + } +} #endif } // namespace vision diff --git a/fastdeploy/vision/common/processors/cvcuda_utils.h b/fastdeploy/vision/common/processors/cvcuda_utils.h index cd4eae8f6..60971ec49 100644 --- a/fastdeploy/vision/common/processors/cvcuda_utils.h +++ b/fastdeploy/vision/common/processors/cvcuda_utils.h @@ -18,6 +18,7 @@ #ifdef ENABLE_CVCUDA #include "nvcv/Tensor.hpp" +#include namespace fastdeploy { namespace vision { @@ -25,7 +26,10 @@ namespace vision { nvcv::ImageFormat CreateCvCudaImageFormat(FDDataType type, int channel); nvcv::TensorWrapData CreateCvCudaTensorWrapData(const FDTensor& tensor); void* GetCvCudaTensorDataPtr(const nvcv::TensorWrapData& tensor); +nvcv::ImageWrapData CreateImageWrapData(const FDTensor& tensor); +void CreateCvCudaImageBatchVarShape(std::vector& tensors, + nvcv::ImageBatchVarShape& img_batch); -} -} 
+} // namespace vision +} // namespace fastdeploy #endif diff --git a/fastdeploy/vision/common/processors/manager.cc b/fastdeploy/vision/common/processors/manager.cc index 147e12ae8..45b29866b 100644 --- a/fastdeploy/vision/common/processors/manager.cc +++ b/fastdeploy/vision/common/processors/manager.cc @@ -62,13 +62,24 @@ bool ProcessorManager::Run(std::vector* images, return false; } - for (size_t i = 0; i < images->size(); ++i) { - if (CudaUsed()) { - SetStream(&((*images)[i])); - } + if (images->size() > input_caches_.size()) { + input_caches_.resize(images->size()); + output_caches_.resize(images->size()); } - bool ret = Apply(images, outputs); + FDMatBatch image_batch(images); + image_batch.input_cache = &batch_input_cache_; + image_batch.output_cache = &batch_output_cache_; + + for (size_t i = 0; i < images->size(); ++i) { + if (CudaUsed()) { + SetStream(&image_batch); + } + (*images)[i].input_cache = &input_caches_[i]; + (*images)[i].output_cache = &output_caches_[i]; + } + + bool ret = Apply(&image_batch, outputs); if (CudaUsed()) { SyncStream(); diff --git a/fastdeploy/vision/common/processors/manager.h b/fastdeploy/vision/common/processors/manager.h index 8721c7e10..6c119ff56 100644 --- a/fastdeploy/vision/common/processors/manager.h +++ b/fastdeploy/vision/common/processors/manager.h @@ -16,6 +16,7 @@ #include "fastdeploy/utils/utils.h" #include "fastdeploy/vision/common/processors/mat.h" +#include "fastdeploy/vision/common/processors/mat_batch.h" namespace fastdeploy { namespace vision { @@ -24,16 +25,28 @@ class FASTDEPLOY_DECL ProcessorManager { public: ~ProcessorManager(); + /** \brief Use CUDA to boost the performance of processors + * + * \param[in] enable_cv_cuda ture: use CV-CUDA, false: use CUDA only + * \param[in] gpu_id GPU device id + * \return true if the preprocess successed, otherwise false + */ void UseCuda(bool enable_cv_cuda = false, int gpu_id = -1); bool CudaUsed(); - void SetStream(Mat* mat) { + void SetStream(FDMat* mat) { #ifdef WITH_GPU mat->SetStream(stream_); #endif } + void SetStream(FDMatBatch* mat_batch) { +#ifdef WITH_GPU + mat_batch->SetStream(stream_); +#endif + } + void SyncStream() { #ifdef WITH_GPU FDASSERT(cudaStreamSynchronize(stream_) == cudaSuccess, @@ -51,13 +64,13 @@ class FASTDEPLOY_DECL ProcessorManager { */ bool Run(std::vector* images, std::vector* outputs); - /** \brief The body of Run() function which needs to be implemented by a derived class + /** \brief Apply() is the body of Run() function, it needs to be implemented by a derived class * - * \param[in] images The input image data list, all the elements are returned by cv::imread() + * \param[in] image_batch The input image batch * \param[in] outputs The output tensors which will feed in runtime * \return true if the preprocess successed, otherwise false */ - virtual bool Apply(std::vector* images, + virtual bool Apply(FDMatBatch* image_batch, std::vector* outputs) = 0; protected: @@ -68,6 +81,11 @@ class FASTDEPLOY_DECL ProcessorManager { cudaStream_t stream_ = nullptr; #endif int device_id_ = -1; + + std::vector input_caches_; + std::vector output_caches_; + FDTensor batch_input_cache_; + FDTensor batch_output_cache_; }; } // namespace vision diff --git a/fastdeploy/vision/common/processors/manager_pybind.cc b/fastdeploy/vision/common/processors/manager_pybind.cc new file mode 100644 index 000000000..65507cce5 --- /dev/null +++ b/fastdeploy/vision/common/processors/manager_pybind.cc @@ -0,0 +1,41 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "fastdeploy/pybind/main.h" + +namespace fastdeploy { +void BindProcessorManager(pybind11::module& m) { + pybind11::class_(m, "ProcessorManager") + .def("run", + [](vision::ProcessorManager& self, + std::vector& im_list) { + std::vector images; + for (size_t i = 0; i < im_list.size(); ++i) { + images.push_back(vision::WrapMat(PyArrayToCvMat(im_list[i]))); + } + std::vector outputs; + if (!self.Run(&images, &outputs)) { + throw std::runtime_error("Failed to process the input data"); + } + if (!self.CudaUsed()) { + for (size_t i = 0; i < outputs.size(); ++i) { + outputs[i].StopSharing(); + } + } + return outputs; + }) + .def("use_cuda", + [](vision::ProcessorManager& self, bool enable_cv_cuda = false, + int gpu_id = -1) { self.UseCuda(enable_cv_cuda, gpu_id); }); +} +} // namespace fastdeploy diff --git a/fastdeploy/vision/common/processors/mat.cc b/fastdeploy/vision/common/processors/mat.cc index 93d11f871..f56d0b585 100644 --- a/fastdeploy/vision/common/processors/mat.cc +++ b/fastdeploy/vision/common/processors/mat.cc @@ -247,5 +247,40 @@ std::vector WrapMat(const std::vector& images) { return mats; } +bool CheckShapeConsistency(std::vector* mats) { + for (size_t i = 1; i < mats->size(); ++i) { + if ((*mats)[i].Channels() != (*mats)[0].Channels() || + (*mats)[i].Width() != (*mats)[0].Width() || + (*mats)[i].Height() != (*mats)[0].Height()) { + return false; + } + } + return true; +} + +FDTensor* CreateCachedGpuInputTensor(Mat* mat) { +#ifdef WITH_GPU + FDTensor* src = mat->Tensor(); + if (src->device == Device::GPU) { + return src; + } else if (src->device == Device::CPU) { + // Mats on CPU, we need copy these tensors from CPU to GPU + FDASSERT(src->Shape().size() == 3, "The CPU tensor must has 3 dims.") + mat->input_cache->Resize(src->Shape(), src->Dtype(), "input_cache", + Device::GPU); + FDASSERT( + cudaMemcpyAsync(mat->input_cache->Data(), src->Data(), src->Nbytes(), + cudaMemcpyHostToDevice, mat->Stream()) == 0, + "[ERROR] Error occurs while copy memory from CPU to GPU."); + return mat->input_cache; + } else { + FDASSERT(false, "FDMat is on unsupported device: %d", src->device); + } +#else + FDASSERT(false, "FastDeploy didn't compile with WITH_GPU."); +#endif + return nullptr; +} + } // namespace vision } // namespace fastdeploy diff --git a/fastdeploy/vision/common/processors/mat.h b/fastdeploy/vision/common/processors/mat.h index 568744a04..c29fdd4b2 100644 --- a/fastdeploy/vision/common/processors/mat.h +++ b/fastdeploy/vision/common/processors/mat.h @@ -119,6 +119,11 @@ struct FASTDEPLOY_DECL Mat { void SetChannels(int s) { channels = s; } void SetWidth(int w) { width = w; } void SetHeight(int h) { height = h; } + + // When using CV-CUDA/CUDA, please set input/output cache, + // refer to manager.cc + FDTensor* input_cache = nullptr; + FDTensor* output_cache = nullptr; #ifdef WITH_GPU cudaStream_t Stream() const { return stream; } void SetStream(cudaStream_t s) { stream = s; } @@ -165,5 +170,12 @@ FASTDEPLOY_DECL 
FDMat WrapMat(const cv::Mat& image); */ FASTDEPLOY_DECL std::vector WrapMat(const std::vector& images); +bool CheckShapeConsistency(std::vector* mats); + +// Create an input tensor on GPU and save into input_cache. +// If the Mat is on GPU, return the mat->Tensor() directly. +// If the Mat is on CPU, then update the input cache tensor and copy the mat's +// CPU tensor to this new GPU input cache tensor. +FDTensor* CreateCachedGpuInputTensor(Mat* mat); } // namespace vision } // namespace fastdeploy diff --git a/fastdeploy/vision/common/processors/mat_batch.cc b/fastdeploy/vision/common/processors/mat_batch.cc new file mode 100644 index 000000000..b73703588 --- /dev/null +++ b/fastdeploy/vision/common/processors/mat_batch.cc @@ -0,0 +1,81 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "fastdeploy/vision/common/processors/mat_batch.h" + +namespace fastdeploy { +namespace vision { + +#ifdef WITH_GPU +void FDMatBatch::SetStream(cudaStream_t s) { + stream = s; + for (size_t i = 0; i < mats->size(); ++i) { + (*mats)[i].SetStream(s); + } +} +#endif + +FDTensor* FDMatBatch::Tensor() { + if (has_batched_tensor) { + return &fd_tensor; + } + FDASSERT(CheckShapeConsistency(mats), "Mats shapes are not consistent.") + // Each mat has its own tensor, + // to get a batched tensor, we need copy these tensors to a batched tensor + FDTensor* src = (*mats)[0].Tensor(); + auto new_shape = src->Shape(); + new_shape.insert(new_shape.begin(), mats->size()); + input_cache->Resize(new_shape, src->Dtype(), "batch_input_cache", device); + for (size_t i = 0; i < mats->size(); ++i) { + FDASSERT(device == (*mats)[i].Tensor()->device, + "Mats and MatBatch are not on the same device"); + uint8_t* p = reinterpret_cast(input_cache->Data()); + int num_bytes = (*mats)[i].Tensor()->Nbytes(); + FDTensor::CopyBuffer(p + i * num_bytes, (*mats)[i].Tensor()->Data(), + num_bytes, device, false); + } + SetTensor(input_cache); + return &fd_tensor; +} + +void FDMatBatch::SetTensor(FDTensor* tensor) { + fd_tensor.SetExternalData(tensor->Shape(), tensor->Dtype(), tensor->Data(), + tensor->device, tensor->device_id); + has_batched_tensor = true; +} + +FDTensor* CreateCachedGpuInputTensor(FDMatBatch* mat_batch) { +#ifdef WITH_GPU + auto mats = mat_batch->mats; + FDASSERT(CheckShapeConsistency(mats), "Mats shapes are not consistent.") + FDTensor* src = (*mats)[0].Tensor(); + if (mat_batch->device == Device::GPU) { + return mat_batch->Tensor(); + } else if (mat_batch->device == Device::CPU) { + // Mats on CPU, we need copy them to GPU and then get a batched GPU tensor + for (size_t i = 0; i < mats->size(); ++i) { + FDTensor* tensor = CreateCachedGpuInputTensor(&(*mats)[i]); + (*mats)[i].SetTensor(tensor); + } + return mat_batch->Tensor(); + } else { + FDASSERT(false, "FDMat is on unsupported device: %d", src->device); + } +#else + FDASSERT(false, "FastDeploy didn't compile with WITH_GPU."); +#endif + return nullptr; +} + +} // namespace vision +} // 
namespace fastdeploy diff --git a/fastdeploy/vision/common/processors/mat_batch.h b/fastdeploy/vision/common/processors/mat_batch.h new file mode 100644 index 000000000..ed5b408c3 --- /dev/null +++ b/fastdeploy/vision/common/processors/mat_batch.h @@ -0,0 +1,76 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "fastdeploy/vision/common/processors/mat.h" + +#ifdef WITH_GPU +#include +#endif + +namespace fastdeploy { +namespace vision { + +enum FDMatBatchLayout { NHWC, NCHW }; + +struct FASTDEPLOY_DECL FDMatBatch { + FDMatBatch() = default; + + // MatBatch is intialized with a list of mats, + // the data is stored in the mats separately. + // Call Tensor() function to get a batched 4-dimension tensor. + explicit FDMatBatch(std::vector* _mats) { + mats = _mats; + layout = FDMatBatchLayout::NHWC; + mat_type = ProcLib::OPENCV; + } + + // Get the batched 4-dimension tensor. + FDTensor* Tensor(); + + void SetTensor(FDTensor* tensor); + + private: +#ifdef WITH_GPU + cudaStream_t stream = nullptr; +#endif + FDTensor fd_tensor; + + public: + // When using CV-CUDA/CUDA, please set input/output cache, + // refer to manager.cc + FDTensor* input_cache; + FDTensor* output_cache; +#ifdef WITH_GPU + cudaStream_t Stream() const { return stream; } + void SetStream(cudaStream_t s); +#endif + + std::vector* mats; + ProcLib mat_type = ProcLib::OPENCV; + FDMatBatchLayout layout = FDMatBatchLayout::NHWC; + Device device = Device::CPU; + + // False: the data is stored in the mats separately + // True: the data is stored in the fd_tensor continuously in 4 dimensions + bool has_batched_tensor = false; +}; + +// Create a batched input tensor on GPU and save into input_cache. +// If the MatBatch is on GPU, return the Tensor() directly. +// If the MatBatch is on CPU, then copy the CPU tensors to GPU and get a GPU +// batched input tensor. +FDTensor* CreateCachedGpuInputTensor(FDMatBatch* mat_batch); + +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/common/processors/normalize_and_permute.cc b/fastdeploy/vision/common/processors/normalize_and_permute.cc index 93850b97f..d38aeca86 100755 --- a/fastdeploy/vision/common/processors/normalize_and_permute.cc +++ b/fastdeploy/vision/common/processors/normalize_and_permute.cc @@ -56,7 +56,7 @@ NormalizeAndPermute::NormalizeAndPermute(const std::vector& mean, swap_rb_ = swap_rb; } -bool NormalizeAndPermute::ImplByOpenCV(Mat* mat) { +bool NormalizeAndPermute::ImplByOpenCV(FDMat* mat) { cv::Mat* im = mat->GetOpenCVMat(); int origin_w = im->cols; int origin_h = im->rows; @@ -79,7 +79,7 @@ bool NormalizeAndPermute::ImplByOpenCV(Mat* mat) { } #ifdef ENABLE_FLYCV -bool NormalizeAndPermute::ImplByFlyCV(Mat* mat) { +bool NormalizeAndPermute::ImplByFlyCV(FDMat* mat) { if (mat->layout != Layout::HWC) { FDERROR << "Only supports input with HWC layout." 
<< std::endl; return false; @@ -109,7 +109,7 @@ bool NormalizeAndPermute::ImplByFlyCV(Mat* mat) { } #endif -bool NormalizeAndPermute::Run(Mat* mat, const std::vector& mean, +bool NormalizeAndPermute::Run(FDMat* mat, const std::vector& mean, const std::vector& std, bool is_scale, const std::vector& min, const std::vector& max, ProcLib lib, diff --git a/fastdeploy/vision/common/processors/normalize_and_permute.cu b/fastdeploy/vision/common/processors/normalize_and_permute.cu index 69bb6af1d..fd482e9d6 100644 --- a/fastdeploy/vision/common/processors/normalize_and_permute.cu +++ b/fastdeploy/vision/common/processors/normalize_and_permute.cu @@ -18,63 +18,110 @@ namespace fastdeploy { namespace vision { -__global__ void NormalizeAndPermuteKernel(uint8_t* src, float* dst, +__global__ void NormalizeAndPermuteKernel(const uint8_t* src, float* dst, const float* alpha, const float* beta, int num_channel, bool swap_rb, - int edge) { + int batch_size, int edge) { int idx = blockDim.x * blockIdx.x + threadIdx.x; if (idx >= edge) return; - if (swap_rb) { - uint8_t tmp = src[num_channel * idx]; - src[num_channel * idx] = src[num_channel * idx + 2]; - src[num_channel * idx + 2] = tmp; - } + int img_size = edge / batch_size; + int n = idx / img_size; // batch index + int p = idx - (n * img_size); // pixel index within the image for (int i = 0; i < num_channel; ++i) { - dst[idx + edge * i] = src[num_channel * idx + i] * alpha[i] + beta[i]; + int j = i; + if (swap_rb) { + j = 2 - i; + } + dst[n * img_size * num_channel + i * img_size + p] = + src[num_channel * idx + j] * alpha[i] + beta[i]; } } -bool NormalizeAndPermute::ImplByCuda(Mat* mat) { +bool NormalizeAndPermute::ImplByCuda(FDMat* mat) { // Prepare input tensor - std::string tensor_name = Name() + "_cvcuda_src"; - FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name); + FDTensor* src = CreateCachedGpuInputTensor(mat); // Prepare output tensor - tensor_name = Name() + "_dst"; - FDTensor* dst = UpdateAndGetCachedTensor(src->Shape(), FDDataType::FP32, - tensor_name, Device::GPU); + mat->output_cache->Resize(src->Shape(), FDDataType::FP32, "output_cache", + Device::GPU); // Copy alpha and beta to GPU - tensor_name = Name() + "_alpha"; - FDMat alpha_mat = - FDMat::Create(1, 1, alpha_.size(), FDDataType::FP32, alpha_.data()); - FDTensor* alpha = CreateCachedGpuInputTensor(&alpha_mat, tensor_name); + gpu_alpha_.Resize({1, 1, static_cast(alpha_.size())}, FDDataType::FP32, + "alpha", Device::GPU); + cudaMemcpy(gpu_alpha_.Data(), alpha_.data(), gpu_alpha_.Nbytes(), + cudaMemcpyHostToDevice); - tensor_name = Name() + "_beta"; - FDMat beta_mat = - FDMat::Create(1, 1, beta_.size(), FDDataType::FP32, beta_.data()); - FDTensor* beta = CreateCachedGpuInputTensor(&beta_mat, tensor_name); + gpu_beta_.Resize({1, 1, static_cast(beta_.size())}, FDDataType::FP32, + "beta", Device::GPU); + cudaMemcpy(gpu_beta_.Data(), beta_.data(), gpu_beta_.Nbytes(), + cudaMemcpyHostToDevice); - int jobs = mat->Width() * mat->Height(); + int jobs = 1 * mat->Width() * mat->Height(); int threads = 256; int blocks = ceil(jobs / (float)threads); NormalizeAndPermuteKernel<<Stream()>>>( reinterpret_cast(src->Data()), - reinterpret_cast(dst->Data()), - reinterpret_cast(alpha->Data()), - reinterpret_cast(beta->Data()), mat->Channels(), swap_rb_, jobs); + reinterpret_cast(mat->output_cache->Data()), + reinterpret_cast(gpu_alpha_.Data()), + reinterpret_cast(gpu_beta_.Data()), mat->Channels(), swap_rb_, 1, + jobs); - mat->SetTensor(dst); + mat->SetTensor(mat->output_cache); mat->device = 
Device::GPU; mat->layout = Layout::CHW; mat->mat_type = ProcLib::CUDA; return true; } +bool NormalizeAndPermute::ImplByCuda(FDMatBatch* mat_batch) { + // Prepare input tensor + FDTensor* src = CreateCachedGpuInputTensor(mat_batch); + + // Prepare output tensor + mat_batch->output_cache->Resize(src->Shape(), FDDataType::FP32, + "output_cache", Device::GPU); + // NHWC -> NCHW + std::swap(mat_batch->output_cache->shape[1], + mat_batch->output_cache->shape[3]); + + // Copy alpha and beta to GPU + gpu_alpha_.Resize({1, 1, static_cast(alpha_.size())}, FDDataType::FP32, + "alpha", Device::GPU); + cudaMemcpy(gpu_alpha_.Data(), alpha_.data(), gpu_alpha_.Nbytes(), + cudaMemcpyHostToDevice); + + gpu_beta_.Resize({1, 1, static_cast(beta_.size())}, FDDataType::FP32, + "beta", Device::GPU); + cudaMemcpy(gpu_beta_.Data(), beta_.data(), gpu_beta_.Nbytes(), + cudaMemcpyHostToDevice); + + int jobs = + mat_batch->output_cache->Numel() / mat_batch->output_cache->shape[1]; + int threads = 256; + int blocks = ceil(jobs / (float)threads); + NormalizeAndPermuteKernel<<Stream()>>>( + reinterpret_cast(src->Data()), + reinterpret_cast(mat_batch->output_cache->Data()), + reinterpret_cast(gpu_alpha_.Data()), + reinterpret_cast(gpu_beta_.Data()), + mat_batch->output_cache->shape[1], swap_rb_, + mat_batch->output_cache->shape[0], jobs); + + mat_batch->SetTensor(mat_batch->output_cache); + mat_batch->device = Device::GPU; + mat_batch->layout = FDMatBatchLayout::NCHW; + mat_batch->mat_type = ProcLib::CUDA; + return true; +} + #ifdef ENABLE_CVCUDA -bool NormalizeAndPermute::ImplByCvCuda(Mat* mat) { return ImplByCuda(mat); } +bool NormalizeAndPermute::ImplByCvCuda(FDMat* mat) { return ImplByCuda(mat); } + +bool NormalizeAndPermute::ImplByCvCuda(FDMatBatch* mat_batch) { + return ImplByCuda(mat_batch); +} #endif } // namespace vision diff --git a/fastdeploy/vision/common/processors/normalize_and_permute.h b/fastdeploy/vision/common/processors/normalize_and_permute.h index ff8394c67..da7039db4 100644 --- a/fastdeploy/vision/common/processors/normalize_and_permute.h +++ b/fastdeploy/vision/common/processors/normalize_and_permute.h @@ -25,15 +25,17 @@ class FASTDEPLOY_DECL NormalizeAndPermute : public Processor { const std::vector& min = std::vector(), const std::vector& max = std::vector(), bool swap_rb = false); - bool ImplByOpenCV(Mat* mat); + bool ImplByOpenCV(FDMat* mat); #ifdef ENABLE_FLYCV - bool ImplByFlyCV(Mat* mat); + bool ImplByFlyCV(FDMat* mat); #endif #ifdef WITH_GPU - bool ImplByCuda(Mat* mat); + bool ImplByCuda(FDMat* mat); + bool ImplByCuda(FDMatBatch* mat_batch); #endif #ifdef ENABLE_CVCUDA - bool ImplByCvCuda(Mat* mat); + bool ImplByCvCuda(FDMat* mat); + bool ImplByCvCuda(FDMatBatch* mat_batch); #endif std::string Name() { return "NormalizeAndPermute"; } @@ -47,7 +49,7 @@ class FASTDEPLOY_DECL NormalizeAndPermute : public Processor { // There will be some precomputation in contruct function // and the `norm(mat)` only need to compute result = mat * alpha + beta // which will reduce lots of time - static bool Run(Mat* mat, const std::vector& mean, + static bool Run(FDMat* mat, const std::vector& mean, const std::vector& std, bool is_scale = true, const std::vector& min = std::vector(), const std::vector& max = std::vector(), @@ -76,6 +78,8 @@ class FASTDEPLOY_DECL NormalizeAndPermute : public Processor { private: std::vector alpha_; std::vector beta_; + FDTensor gpu_alpha_; + FDTensor gpu_beta_; bool swap_rb_; }; } // namespace vision diff --git a/fastdeploy/vision/common/processors/resize.cc 
b/fastdeploy/vision/common/processors/resize.cc index 29a8798ad..0de6ddfc7 100644 --- a/fastdeploy/vision/common/processors/resize.cc +++ b/fastdeploy/vision/common/processors/resize.cc @@ -23,7 +23,7 @@ namespace fastdeploy { namespace vision { -bool Resize::ImplByOpenCV(Mat* mat) { +bool Resize::ImplByOpenCV(FDMat* mat) { if (mat->layout != Layout::HWC) { FDERROR << "Resize: The format of input is not HWC." << std::endl; return false; @@ -61,7 +61,7 @@ bool Resize::ImplByOpenCV(Mat* mat) { } #ifdef ENABLE_FLYCV -bool Resize::ImplByFlyCV(Mat* mat) { +bool Resize::ImplByFlyCV(FDMat* mat) { if (mat->layout != Layout::HWC) { FDERROR << "Resize: The format of input is not HWC." << std::endl; return false; @@ -123,7 +123,7 @@ bool Resize::ImplByFlyCV(Mat* mat) { #endif #ifdef ENABLE_CVCUDA -bool Resize::ImplByCvCuda(Mat* mat) { +bool Resize::ImplByCvCuda(FDMat* mat) { if (width_ == mat->Width() && height_ == mat->Height()) { return true; } @@ -143,23 +143,20 @@ bool Resize::ImplByCvCuda(Mat* mat) { } // Prepare input tensor - std::string tensor_name = Name() + "_cvcuda_src"; - FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name); + FDTensor* src = CreateCachedGpuInputTensor(mat); auto src_tensor = CreateCvCudaTensorWrapData(*src); // Prepare output tensor - tensor_name = Name() + "_cvcuda_dst"; - FDTensor* dst = - UpdateAndGetCachedTensor({height_, width_, mat->Channels()}, mat->Type(), - tensor_name, Device::GPU); - auto dst_tensor = CreateCvCudaTensorWrapData(*dst); + mat->output_cache->Resize({height_, width_, mat->Channels()}, mat->Type(), + "output_cache", Device::GPU); + auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache)); // CV-CUDA Interp value is compatible with OpenCV cvcuda::Resize resize_op; resize_op(mat->Stream(), src_tensor, dst_tensor, NVCVInterpolationType(interp_)); - mat->SetTensor(dst); + mat->SetTensor(mat->output_cache); mat->SetWidth(width_); mat->SetHeight(height_); mat->device = Device::GPU; @@ -168,8 +165,8 @@ bool Resize::ImplByCvCuda(Mat* mat) { } #endif -bool Resize::Run(Mat* mat, int width, int height, float scale_w, float scale_h, - int interp, bool use_scale, ProcLib lib) { +bool Resize::Run(FDMat* mat, int width, int height, float scale_w, + float scale_h, int interp, bool use_scale, ProcLib lib) { if (mat->Height() == height && mat->Width() == width) { return true; } diff --git a/fastdeploy/vision/common/processors/resize.h b/fastdeploy/vision/common/processors/resize.h index 54480108b..2b4f88a35 100644 --- a/fastdeploy/vision/common/processors/resize.h +++ b/fastdeploy/vision/common/processors/resize.h @@ -31,16 +31,16 @@ class FASTDEPLOY_DECL Resize : public Processor { use_scale_ = use_scale; } - bool ImplByOpenCV(Mat* mat); + bool ImplByOpenCV(FDMat* mat); #ifdef ENABLE_FLYCV - bool ImplByFlyCV(Mat* mat); + bool ImplByFlyCV(FDMat* mat); #endif #ifdef ENABLE_CVCUDA - bool ImplByCvCuda(Mat* mat); + bool ImplByCvCuda(FDMat* mat); #endif std::string Name() { return "Resize"; } - static bool Run(Mat* mat, int width, int height, float scale_w = -1.0, + static bool Run(FDMat* mat, int width, int height, float scale_w = -1.0, float scale_h = -1.0, int interp = 1, bool use_scale = false, ProcLib lib = ProcLib::DEFAULT); diff --git a/fastdeploy/vision/common/processors/resize_by_short.cc b/fastdeploy/vision/common/processors/resize_by_short.cc index 1d6309f5d..535652fc7 100644 --- a/fastdeploy/vision/common/processors/resize_by_short.cc +++ b/fastdeploy/vision/common/processors/resize_by_short.cc @@ -23,7 +23,7 @@ namespace fastdeploy { 
namespace vision { -bool ResizeByShort::ImplByOpenCV(Mat* mat) { +bool ResizeByShort::ImplByOpenCV(FDMat* mat) { cv::Mat* im = mat->GetOpenCVMat(); int origin_w = im->cols; int origin_h = im->rows; @@ -43,7 +43,7 @@ bool ResizeByShort::ImplByOpenCV(Mat* mat) { } #ifdef ENABLE_FLYCV -bool ResizeByShort::ImplByFlyCV(Mat* mat) { +bool ResizeByShort::ImplByFlyCV(FDMat* mat) { fcv::Mat* im = mat->GetFlyCVMat(); int origin_w = im->width(); int origin_h = im->height(); @@ -87,10 +87,9 @@ bool ResizeByShort::ImplByFlyCV(Mat* mat) { #endif #ifdef ENABLE_CVCUDA -bool ResizeByShort::ImplByCvCuda(Mat* mat) { +bool ResizeByShort::ImplByCvCuda(FDMat* mat) { // Prepare input tensor - std::string tensor_name = Name() + "_cvcuda_src"; - FDTensor* src = CreateCachedGpuInputTensor(mat, tensor_name); + FDTensor* src = CreateCachedGpuInputTensor(mat); auto src_tensor = CreateCvCudaTensorWrapData(*src); double scale = GenerateScale(mat->Width(), mat->Height()); @@ -98,23 +97,69 @@ bool ResizeByShort::ImplByCvCuda(Mat* mat) { int height = static_cast(round(scale * mat->Height())); // Prepare output tensor - tensor_name = Name() + "_cvcuda_dst"; - FDTensor* dst = UpdateAndGetCachedTensor( - {height, width, mat->Channels()}, mat->Type(), tensor_name, Device::GPU); - auto dst_tensor = CreateCvCudaTensorWrapData(*dst); + mat->output_cache->Resize({height, width, mat->Channels()}, mat->Type(), + "output_cache", Device::GPU); + auto dst_tensor = CreateCvCudaTensorWrapData(*(mat->output_cache)); // CV-CUDA Interp value is compatible with OpenCV cvcuda::Resize resize_op; resize_op(mat->Stream(), src_tensor, dst_tensor, NVCVInterpolationType(interp_)); - mat->SetTensor(dst); + mat->SetTensor(mat->output_cache); mat->SetWidth(width); mat->SetHeight(height); mat->device = Device::GPU; mat->mat_type = ProcLib::CVCUDA; return true; } + +bool ResizeByShort::ImplByCvCuda(FDMatBatch* mat_batch) { + // TODO(wangxinyu): to support batched tensor as input + FDASSERT(mat_batch->has_batched_tensor == false, + "ResizeByShort doesn't support batched tensor as input for now."); + // Prepare input batch + std::string tensor_name = Name() + "_cvcuda_src"; + std::vector src_tensors; + for (size_t i = 0; i < mat_batch->mats->size(); ++i) { + FDTensor* src = CreateCachedGpuInputTensor(&(*(mat_batch->mats))[i]); + src_tensors.push_back(src); + } + nvcv::ImageBatchVarShape src_batch(mat_batch->mats->size()); + CreateCvCudaImageBatchVarShape(src_tensors, src_batch); + + // Prepare output batch + tensor_name = Name() + "_cvcuda_dst"; + std::vector dst_tensors; + for (size_t i = 0; i < mat_batch->mats->size(); ++i) { + FDMat* mat = &(*(mat_batch->mats))[i]; + double scale = GenerateScale(mat->Width(), mat->Height()); + int width = static_cast(round(scale * mat->Width())); + int height = static_cast(round(scale * mat->Height())); + mat->output_cache->Resize({height, width, mat->Channels()}, mat->Type(), + "output_cache", Device::GPU); + dst_tensors.push_back(mat->output_cache); + } + nvcv::ImageBatchVarShape dst_batch(mat_batch->mats->size()); + CreateCvCudaImageBatchVarShape(dst_tensors, dst_batch); + + // CV-CUDA Interp value is compatible with OpenCV + cvcuda::Resize resize_op; + resize_op(mat_batch->Stream(), src_batch, dst_batch, + NVCVInterpolationType(interp_)); + + for (size_t i = 0; i < mat_batch->mats->size(); ++i) { + FDMat* mat = &(*(mat_batch->mats))[i]; + mat->SetTensor(dst_tensors[i]); + mat->SetWidth(dst_tensors[i]->Shape()[1]); + mat->SetHeight(dst_tensors[i]->Shape()[0]); + mat->device = Device::GPU; + mat->mat_type = 
ProcLib::CVCUDA; + } + mat_batch->device = Device::GPU; + mat_batch->mat_type = ProcLib::CVCUDA; + return true; +} #endif double ResizeByShort::GenerateScale(const int origin_w, const int origin_h) { @@ -143,7 +188,7 @@ double ResizeByShort::GenerateScale(const int origin_w, const int origin_h) { return scale; } -bool ResizeByShort::Run(Mat* mat, int target_size, int interp, bool use_scale, +bool ResizeByShort::Run(FDMat* mat, int target_size, int interp, bool use_scale, const std::vector& max_hw, ProcLib lib) { auto r = ResizeByShort(target_size, interp, use_scale, max_hw); return r(mat, lib); diff --git a/fastdeploy/vision/common/processors/resize_by_short.h b/fastdeploy/vision/common/processors/resize_by_short.h index 64a7f09f0..99078c708 100644 --- a/fastdeploy/vision/common/processors/resize_by_short.h +++ b/fastdeploy/vision/common/processors/resize_by_short.h @@ -28,16 +28,17 @@ class FASTDEPLOY_DECL ResizeByShort : public Processor { interp_ = interp; use_scale_ = use_scale; } - bool ImplByOpenCV(Mat* mat); + bool ImplByOpenCV(FDMat* mat); #ifdef ENABLE_FLYCV - bool ImplByFlyCV(Mat* mat); + bool ImplByFlyCV(FDMat* mat); #endif #ifdef ENABLE_CVCUDA - bool ImplByCvCuda(Mat* mat); + bool ImplByCvCuda(FDMat* mat); + bool ImplByCvCuda(FDMatBatch* mat_batch); #endif std::string Name() { return "ResizeByShort"; } - static bool Run(Mat* mat, int target_size, int interp = 1, + static bool Run(FDMat* mat, int target_size, int interp = 1, bool use_scale = true, const std::vector& max_hw = std::vector(), ProcLib lib = ProcLib::DEFAULT); diff --git a/fastdeploy/vision/vision_pybind.cc b/fastdeploy/vision/vision_pybind.cc index 22f7581be..03e625728 100755 --- a/fastdeploy/vision/vision_pybind.cc +++ b/fastdeploy/vision/vision_pybind.cc @@ -16,6 +16,7 @@ namespace fastdeploy { +void BindProcessorManager(pybind11::module& m); void BindDetection(pybind11::module& m); void BindClassification(pybind11::module& m); void BindSegmentation(pybind11::module& m); @@ -204,6 +205,7 @@ void BindVision(pybind11::module& m) { m.def("disable_flycv", &vision::DisableFlyCV, "Disable image preprocessing by FlyCV, change to use OpenCV."); + BindProcessorManager(m); BindDetection(m); BindClassification(m); BindSegmentation(m); diff --git a/python/fastdeploy/vision/classification/ppcls/__init__.py b/python/fastdeploy/vision/classification/ppcls/__init__.py index 455702271..7215bcfbc 100644 --- a/python/fastdeploy/vision/classification/ppcls/__init__.py +++ b/python/fastdeploy/vision/classification/ppcls/__init__.py @@ -16,44 +16,40 @@ from __future__ import absolute_import import logging from .... import FastDeployModel, ModelFormat from .... 
import c_lib_wrap as C +from ...common import ProcessorManager -class PaddleClasPreprocessor: +class PaddleClasPreprocessor(ProcessorManager): def __init__(self, config_file): """Create a preprocessor for PaddleClasModel from configuration file :param config_file: (str)Path of configuration file, e.g resnet50/inference_cls.yaml """ - self._preprocessor = C.vision.classification.PaddleClasPreprocessor( + super(PaddleClasPreprocessor, self).__init__() + self._manager = C.vision.classification.PaddleClasPreprocessor( config_file) - def run(self, input_ims): - """Preprocess input images for PaddleClasModel - - :param: input_ims: (list of numpy.ndarray)The input image - :return: list of FDTensor - """ - return self._preprocessor.run(input_ims) - - def use_cuda(self, enable_cv_cuda=False, gpu_id=-1): - """Use CUDA preprocessors - - :param: enable_cv_cuda: Whether to enable CV-CUDA - :param: gpu_id: GPU device id - """ - return self._preprocessor.use_cuda(enable_cv_cuda, gpu_id) - def disable_normalize(self): """ This function will disable normalize in preprocessing step. """ - self._preprocessor.disable_normalize() + self._manager.disable_normalize() def disable_permute(self): """ This function will disable hwc2chw in preprocessing step. """ - self._preprocessor.disable_permute() + self._manager.disable_permute() + + def initial_resize_on_cpu(self, v): + """ + When the initial operator is Resize, and input image size is large, + maybe it's better to run resize on CPU, because the HostToDevice memcpy + is time consuming. Set this True to run the initial resize on CPU. + + :param: v: True or False + """ + self._manager.initial_resize_on_cpu(v) class PaddleClasPostprocessor: diff --git a/python/fastdeploy/vision/common/__init__.py b/python/fastdeploy/vision/common/__init__.py new file mode 100644 index 000000000..6e010a427 --- /dev/null +++ b/python/fastdeploy/vision/common/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import + +from .manager import ProcessorManager diff --git a/python/fastdeploy/vision/common/manager.py b/python/fastdeploy/vision/common/manager.py new file mode 100644 index 000000000..05da3d68e --- /dev/null +++ b/python/fastdeploy/vision/common/manager.py @@ -0,0 +1,36 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import absolute_import + + +class ProcessorManager: + def __init__(self): + self._manager = None + + def run(self, input_ims): + """Process input image + + :param: input_ims: (list of numpy.ndarray) The input images + :return: list of FDTensor + """ + return self._manager.run(input_ims) + + def use_cuda(self, enable_cv_cuda=False, gpu_id=-1): + """Use CUDA processors + + :param: enable_cv_cuda: Ture: use CV-CUDA, False: use CUDA only + :param: gpu_id: GPU device id + """ + return self._manager.use_cuda(enable_cv_cuda, gpu_id) diff --git a/python/setup.py b/python/setup.py index 01246283a..6d88a87e1 100755 --- a/python/setup.py +++ b/python/setup.py @@ -72,6 +72,7 @@ setup_configs["PADDLELITE_URL"] = os.getenv("PADDLELITE_URL", "OFF") setup_configs["ENABLE_VISION"] = os.getenv("ENABLE_VISION", "OFF") setup_configs["ENABLE_ENCRYPTION"] = os.getenv("ENABLE_ENCRYPTION", "OFF") setup_configs["ENABLE_FLYCV"] = os.getenv("ENABLE_FLYCV", "OFF") +setup_configs["ENABLE_CVCUDA"] = os.getenv("ENABLE_CVCUDA", "OFF") setup_configs["ENABLE_TEXT"] = os.getenv("ENABLE_TEXT", "OFF") setup_configs["ENABLE_BENCHMARK"] = os.getenv("ENABLE_BENCHMARK", "OFF") setup_configs["WITH_GPU"] = os.getenv("WITH_GPU", "OFF")
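Taken together, PaddleClasPreprocessor now inherits Run() from ProcessorManager, processes an FDMatBatch, and exposes UseCuda() and InitialResizeOnCpu(). A minimal C++ usage sketch of that path — the config and image paths are placeholders, and the CUDA/CV-CUDA calls assume a build with WITH_GPU=ON and ENABLE_CVCUDA=ON:

```cpp
#include <iostream>
#include <vector>

#include "fastdeploy/vision.h"
#include "opencv2/opencv.hpp"

int main() {
  namespace vis = fastdeploy::vision;
  // "inference_cls.yaml" and "test.jpg" are placeholder paths.
  vis::classification::PaddleClasPreprocessor preprocessor("inference_cls.yaml");
  preprocessor.UseCuda(/*enable_cv_cuda=*/true, /*gpu_id=*/0);
  preprocessor.InitialResizeOnCpu(true);  // keep the first Resize on OpenCV/CPU

  std::vector<cv::Mat> images = {cv::imread("test.jpg"), cv::imread("test.jpg")};
  std::vector<vis::FDMat> mats = vis::WrapMat(images);

  std::vector<fastdeploy::FDTensor> outputs;
  if (!preprocessor.Run(&mats, &outputs)) {
    std::cerr << "Preprocess failed." << std::endl;
    return -1;
  }
  // outputs[0] is the batched NCHW tensor assembled via FDMatBatch.
  std::cout << "Output elements: " << outputs[0].Numel() << std::endl;
  return 0;
}
```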