[Backend] cuda normalize and permute, cuda concat, optimized ppcls, ppdet & ppseg (#546)

* cuda normalize and permute, cuda concat

* add use cuda option for preprocessor

* ppyoloe use cuda normalize

* ppseg use cuda normalize

* add proclib cuda in processor base

* ppcls add use cuda preprocess api

* ppcls preprocessor set gpu id

* fix pybind

* refine ppcls preprocessing use gpu logic

* fdtensor device id is -1 by default

* refine assert message

Co-authored-by: heliqi <1101791222@qq.com>
This commit is contained in:
Wang Xinyu
2022-11-14 18:44:00 +08:00
committed by GitHub
parent 8dec2115d5
commit a36f5d3396
20 changed files with 204 additions and 26 deletions

View File

@@ -252,7 +252,8 @@ void FDTensor::FreeFn() {
}
}
void FDTensor::CopyBuffer(void* dst, const void* src, size_t nbytes) {
void FDTensor::CopyBuffer(void* dst, const void* src, size_t nbytes,
const Device& device, bool is_pinned_memory) {
if (device == Device::GPU) {
#ifdef WITH_GPU
FDASSERT(cudaMemcpy(dst, src, nbytes, cudaMemcpyDeviceToDevice) == 0,
@@ -295,7 +296,7 @@ FDTensor::FDTensor(const FDTensor& other)
size_t nbytes = Nbytes();
FDASSERT(ReallocFn(nbytes),
"The FastDeploy FDTensor allocate memory error");
CopyBuffer(buffer_, other.buffer_, nbytes);
CopyBuffer(buffer_, other.buffer_, nbytes, device, is_pinned_memory);
}
}
@@ -325,7 +326,7 @@ FDTensor& FDTensor::operator=(const FDTensor& other) {
} else {
Resize(other.shape, other.dtype, other.name, other.device);
size_t nbytes = Nbytes();
CopyBuffer(buffer_, other.buffer_, nbytes);
CopyBuffer(buffer_, other.buffer_, nbytes, device, is_pinned_memory);
}
external_data_ptr = other.external_data_ptr;
}

View File

@@ -39,6 +39,9 @@ struct FASTDEPLOY_DECL FDTensor {
// GPU to inference the model
// so we can skip data transfer, which may improve the efficience
Device device = Device::CPU;
// By default the device id of FDTensor is -1, which means this value is
// invalid, and FDTensor is using the same device id as Runtime.
int device_id = -1;
// Whether the data buffer is in pinned memory, which is allocated
// with cudaMallocHost()
@@ -130,8 +133,9 @@ struct FASTDEPLOY_DECL FDTensor {
~FDTensor() { FreeFn(); }
private:
void CopyBuffer(void* dst, const void* src, size_t nbytes);
static void CopyBuffer(void* dst, const void* src, size_t nbytes,
const Device& device = Device::CPU,
bool is_pinned_memory = false);
};
} // namespace fastdeploy

View File

@@ -85,8 +85,9 @@ struct ConcatFunctor {
int64_t col_len = input_cols[j];
const T* input_data = reinterpret_cast<const T*>(input[j].Data());
for (int64_t k = 0; k < out_rows; ++k) {
std::memcpy(output_data + k * out_cols + col_idx,
input_data + k * col_len, sizeof(T) * col_len);
FDTensor::CopyBuffer(output_data + k * out_cols + col_idx,
input_data + k * col_len, sizeof(T) * col_len,
input[j].device, input[j].is_pinned_memory);
}
col_idx += col_len;
}
@@ -97,7 +98,8 @@ template <typename T>
void ConcatKernel(const std::vector<FDTensor>& input, FDTensor* output,
int axis) {
auto output_shape = ComputeAndCheckConcatOutputShape(input, axis);
output->Allocate(output_shape, TypeToDataType<T>::dtype);
output->Resize(output_shape, TypeToDataType<T>::dtype, output->name,
input[0].device);
ConcatFunctor<T> functor;
functor(input, axis, output);
@@ -115,10 +117,9 @@ void Concat(const std::vector<FDTensor>& x, FDTensor* out, int axis) {
if (axis < 0) {
axis += rank;
}
FDTensor out_temp;
FD_VISIT_ALL_TYPES(x[0].dtype, "Concat",
([&] { ConcatKernel<data_t>(x, &out_temp, axis); }));
*out = std::move(out_temp);
([&] { ConcatKernel<data_t>(x, out, axis); }));
}
} // namespace function

View File

@@ -568,6 +568,11 @@ std::vector<TensorInfo> Runtime::GetOutputInfos() {
bool Runtime::Infer(std::vector<FDTensor>& input_tensors,
std::vector<FDTensor>* output_tensors) {
for (auto& tensor: input_tensors) {
FDASSERT(tensor.device_id < 0 || tensor.device_id == option.device_id,
"Device id of input tensor(%d) and runtime(%d) are not same.",
tensor.device_id, option.device_id);
}
return backend_->Infer(input_tensors, output_tensors);
}

0
fastdeploy/vision/classification/ppcls/model.cc Executable file → Normal file
View File

View File

@@ -28,6 +28,9 @@ void BindPaddleClas(pybind11::module& m) {
pybind11::eval("raise Exception('Failed to preprocess the input data in PaddleClasPreprocessor.')");
}
return outputs;
})
.def("use_gpu", [](vision::classification::PaddleClasPreprocessor& self, int gpu_id = -1) {
self.UseGpu(gpu_id);
});
pybind11::class_<vision::classification::PaddleClasPostprocessor>(

View File

@@ -15,17 +15,22 @@
#include "fastdeploy/vision/classification/ppcls/preprocessor.h"
#include "fastdeploy/function/concat.h"
#include "yaml-cpp/yaml.h"
#ifdef WITH_GPU
#include <cuda_runtime_api.h>
#endif
namespace fastdeploy {
namespace vision {
namespace classification {
PaddleClasPreprocessor::PaddleClasPreprocessor(const std::string& config_file) {
FDASSERT(BuildPreprocessPipelineFromConfig(config_file), "Failed to create PaddleClasPreprocessor.");
FDASSERT(BuildPreprocessPipelineFromConfig(config_file),
"Failed to create PaddleClasPreprocessor.");
initialized_ = true;
}
bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig(const std::string& config_file) {
bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig(
const std::string& config_file) {
processors_.clear();
YAML::Node cfg;
try {
@@ -73,6 +78,19 @@ bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig(const std::string
return true;
}
void PaddleClasPreprocessor::UseGpu(int gpu_id) {
#ifdef WITH_GPU
use_cuda_ = true;
if (gpu_id < 0) return;
device_id_ = gpu_id;
cudaSetDevice(device_id_);
#else
FDWARNING << "FastDeploy didn't compile with WITH_GPU. "
<< "Will force to use CPU to run preprocessing." << std::endl;
use_cuda_ = false;
#endif
}
bool PaddleClasPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs) {
if (!initialized_) {
FDERROR << "The preprocessor is not initialized." << std::endl;
@@ -85,8 +103,15 @@ bool PaddleClasPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTenso
for (size_t i = 0; i < images->size(); ++i) {
for (size_t j = 0; j < processors_.size(); ++j) {
if (!(*(processors_[j].get()))(&((*images)[i]))) {
FDERROR << "Failed to processs image:" << i << " in " << processors_[i]->Name() << "." << std::endl;
bool ret = false;
if (processors_[j]->Name() == "NormalizeAndPermute" && use_cuda_) {
ret = (*(processors_[j].get()))(&((*images)[i]), ProcLib::CUDA);
} else {
ret = (*(processors_[j].get()))(&((*images)[i]));
}
if (!ret) {
FDERROR << "Failed to processs image:" << i << " in "
<< processors_[i]->Name() << "." << std::endl;
return false;
}
}
@@ -104,6 +129,7 @@ bool PaddleClasPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTenso
} else {
function::Concat(tensors, &((*outputs)[0]), 0);
}
(*outputs)[0].device_id = device_id_;
return true;
}

View File

@@ -38,11 +38,19 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor {
*/
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);
/** \brief Use GPU to run preprocessing
*
* \param[in] gpu_id GPU device id
*/
void UseGpu(int gpu_id = -1);
private:
bool BuildPreprocessPipelineFromConfig(const std::string& config_file);
std::vector<std::shared_ptr<Processor>> processors_;
bool initialized_ = false;
bool use_cuda_ = false;
// GPU device id
int device_id_ = -1;
};
} // namespace classification

View File

@@ -30,12 +30,32 @@ bool Processor::operator()(Mat* mat, ProcLib lib) {
return ImplByFlyCV(mat);
#else
FDASSERT(false, "FastDeploy didn't compile with FlyCV.");
#endif
} else if (target == ProcLib::CUDA) {
#ifdef WITH_GPU
return ImplByCuda(mat);
#else
FDASSERT(false, "FastDeploy didn't compile with WITH_GPU.");
#endif
}
// DEFAULT & OPENCV
return ImplByOpenCV(mat);
}
FDTensor* Processor::UpdateAndGetReusedBuffer(
const std::vector<int64_t>& new_shape, const int& opencv_dtype,
const std::string& buffer_name, const Device& new_device,
const bool& use_pinned_memory) {
if (reused_buffers_.count(buffer_name) == 0) {
reused_buffers_[buffer_name] = FDTensor();
}
reused_buffers_[buffer_name].is_pinned_memory = use_pinned_memory;
reused_buffers_[buffer_name].Resize(new_shape,
OpenCVDataTypeToFD(opencv_dtype),
buffer_name, new_device);
return &reused_buffers_[buffer_name];
}
void EnableFlyCV() {
#ifdef ENABLE_FLYCV
DefaultProcLib::default_lib = ProcLib::FLYCV;

View File

@@ -18,6 +18,7 @@
#include "fastdeploy/vision/common/processors/mat.h"
#include "opencv2/highgui/highgui.hpp"
#include "opencv2/imgproc/imgproc.hpp"
#include <unordered_map>
namespace fastdeploy {
namespace vision {
@@ -55,7 +56,20 @@ class FASTDEPLOY_DECL Processor {
return ImplByOpenCV(mat);
}
virtual bool ImplByCuda(Mat* mat) {
return ImplByOpenCV(mat);
}
virtual bool operator()(Mat* mat, ProcLib lib = ProcLib::DEFAULT);
protected:
FDTensor* UpdateAndGetReusedBuffer(
const std::vector<int64_t>& new_shape, const int& opencv_dtype,
const std::string& buffer_name, const Device& new_device = Device::CPU,
const bool& use_pinned_memory = false);
private:
std::unordered_map<std::string, FDTensor> reused_buffers_;
};
} // namespace vision

View File

@@ -34,7 +34,7 @@ void* Mat::Data() {
void Mat::ShareWithTensor(FDTensor* tensor) {
tensor->SetExternalData({Channels(), Height(), Width()}, Type(), Data());
tensor->device = Device::CPU;
tensor->device = device;
if (layout == Layout::HWC) {
tensor->shape = {Height(), Width(), Channels()};
}

View File

@@ -131,6 +131,7 @@ struct FASTDEPLOY_DECL Mat {
ProcLib mat_type = ProcLib::OPENCV;
Layout layout = Layout::HWC;
Device device = Device::CPU;
// Create FD Mat from FD Tensor. This method only create a
// new FD Mat with zero copy and it's data pointer is reference

View File

@@ -73,7 +73,6 @@ bool NormalizeAndPermute::ImplByOpenCV(Mat* mat) {
res.ptr() + i * origin_h * origin_w * 4),
0);
}
mat->SetMat(res);
mat->layout = Layout::CHW;
return true;

View File

@@ -0,0 +1,82 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/common/processors/normalize_and_permute.h"
namespace fastdeploy {
namespace vision {
__global__ void NormalizeAndPermuteKernel(
uint8_t* src, float* dst, const float* alpha, const float* beta,
int num_channel, bool swap_rb, int edge) {
int idx = blockDim.x * blockIdx.x + threadIdx.x;
if (idx >= edge) return;
if (swap_rb) {
uint8_t tmp = src[num_channel * idx];
src[num_channel * idx] = src[num_channel * idx + 2];
src[num_channel * idx + 2] = tmp;
}
for (int i = 0; i < num_channel; ++i) {
dst[idx + edge * i] = src[num_channel * idx + i] * alpha[i] + beta[i];
}
}
bool NormalizeAndPermute::ImplByCuda(Mat* mat) {
cv::Mat* im = mat->GetOpenCVMat();
std::string buf_name = Name() + "_src";
std::vector<int64_t> shape = {im->rows, im->cols, im->channels()};
FDTensor* src = UpdateAndGetReusedBuffer(shape, im->type(), buf_name,
Device::GPU);
FDASSERT(cudaMemcpy(src->Data(), im->ptr(), src->Nbytes(),
cudaMemcpyHostToDevice) == 0,
"Error occurs while copy memory from CPU to GPU.");
buf_name = Name() + "_dst";
FDTensor* dst = UpdateAndGetReusedBuffer(shape, CV_32FC(im->channels()),
buf_name, Device::GPU);
cv::Mat res(im->rows, im->cols, CV_32FC(im->channels()), dst->Data());
buf_name = Name() + "_alpha";
FDTensor* alpha = UpdateAndGetReusedBuffer({(int64_t)alpha_.size()}, CV_32FC1,
buf_name, Device::GPU);
FDASSERT(cudaMemcpy(alpha->Data(), alpha_.data(), alpha->Nbytes(),
cudaMemcpyHostToDevice) == 0,
"Error occurs while copy memory from CPU to GPU.");
buf_name = Name() + "_beta";
FDTensor* beta = UpdateAndGetReusedBuffer({(int64_t)beta_.size()}, CV_32FC1,
buf_name, Device::GPU);
FDASSERT(cudaMemcpy(beta->Data(), beta_.data(), beta->Nbytes(),
cudaMemcpyHostToDevice) == 0,
"Error occurs while copy memory from CPU to GPU.");
int jobs = im->cols * im->rows;
int threads = 256;
int blocks = ceil(jobs / (float)threads);
NormalizeAndPermuteKernel<<<blocks, threads, 0, NULL>>>(
reinterpret_cast<uint8_t*>(src->Data()),
reinterpret_cast<float*>(dst->Data()),
reinterpret_cast<float*>(alpha->Data()),
reinterpret_cast<float*>(beta->Data()), im->channels(), swap_rb_, jobs);
mat->SetMat(res);
mat->device = Device::GPU;
mat->layout = Layout::CHW;
return true;
}
} // namespace vision
} // namespace fastdeploy

View File

@@ -28,6 +28,9 @@ class FASTDEPLOY_DECL NormalizeAndPermute : public Processor {
bool ImplByOpenCV(Mat* mat);
#ifdef ENABLE_FLYCV
bool ImplByFlyCV(Mat* mat);
#endif
#ifdef WITH_GPU
bool ImplByCuda(Mat* mat);
#endif
std::string Name() { return "NormalizeAndPermute"; }

View File

@@ -30,6 +30,9 @@ std::ostream& operator<<(std::ostream& out, const ProcLib& p) {
case ProcLib::FLYCV:
out << "ProcLib::FLYCV";
break;
case ProcLib::CUDA:
out << "ProcLib::CUDA";
break;
default:
FDASSERT(false, "Unknow type of ProcLib.");
}

View File

@@ -18,7 +18,7 @@
namespace fastdeploy {
namespace vision {
enum class FASTDEPLOY_DECL ProcLib { DEFAULT, OPENCV, FLYCV };
enum class FASTDEPLOY_DECL ProcLib { DEFAULT, OPENCV, FLYCV, CUDA };
FASTDEPLOY_DECL std::ostream& operator<<(std::ostream& out, const ProcLib& p);

View File

@@ -146,7 +146,6 @@ bool PPYOLOE::BuildPreprocessPipelineFromConfig() {
// Fusion will improve performance
FuseTransforms(&processors_);
return true;
}
@@ -161,8 +160,6 @@ bool PPYOLOE::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
}
}
Cast::Run(mat, "float");
outputs->resize(2);
(*outputs)[0].name = InputInfoOfRuntime(0).name;
mat->ShareWithTensor(&((*outputs)[0]));

View File

@@ -39,6 +39,8 @@ PaddleSegModel::PaddleSegModel(const std::string& model_file,
}
bool PaddleSegModel::Initialize() {
reused_input_tensors_.resize(1);
reused_output_tensors_.resize(1);
if (!BuildPreprocessPipelineFromConfig()) {
FDERROR << "Failed to build preprocess pipeline from configuration file."
<< std::endl;
@@ -133,6 +135,9 @@ bool PaddleSegModel::BuildPreprocessPipelineFromConfig() {
if (!(this->disable_normalize_and_permute)) {
processors_.push_back(std::make_shared<HWC2CHW>());
}
// Fusion will improve performance
FuseTransforms(&processors_);
return true;
}
@@ -332,7 +337,6 @@ bool PaddleSegModel::Postprocess(
bool PaddleSegModel::Predict(cv::Mat* im, SegmentationResult* result) {
Mat mat(*im);
std::vector<FDTensor> processed_data(1);
std::map<std::string, std::array<int, 2>> im_info;
@@ -340,18 +344,18 @@ bool PaddleSegModel::Predict(cv::Mat* im, SegmentationResult* result) {
im_info["input_shape"] = {static_cast<int>(mat.Height()),
static_cast<int>(mat.Width())};
if (!Preprocess(&mat, &(processed_data[0]))) {
if (!Preprocess(&mat, &(reused_input_tensors_[0]))) {
FDERROR << "Failed to preprocess input data while using model:"
<< ModelName() << "." << std::endl;
return false;
}
std::vector<FDTensor> infer_result(1);
if (!Infer(processed_data, &infer_result)) {
if (!Infer()) {
FDERROR << "Failed to inference while using model:" << ModelName() << "."
<< std::endl;
return false;
}
if (!Postprocess(&infer_result[0], result, im_info)) {
if (!Postprocess(&reused_output_tensors_[0], result, im_info)) {
FDERROR << "Failed to postprocess while using model:" << ModelName() << "."
<< std::endl;
return false;

View File

@@ -35,6 +35,13 @@ class PaddleClasPreprocessor:
"""
return self._preprocessor.run(input_ims)
def use_gpu(self, gpu_id=-1):
"""Use CUDA preprocessors
:param: gpu_id: GPU device id
"""
return self._preprocessor.use_gpu(gpu_id)
class PaddleClasPostprocessor:
def __init__(self, topk=1):