mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[Backend] cuda normalize and permute, cuda concat, optimized ppcls, ppdet & ppseg (#546)
* cuda normalize and permute, cuda concat * add use cuda option for preprocessor * ppyoloe use cuda normalize * ppseg use cuda normalize * add proclib cuda in processor base * ppcls add use cuda preprocess api * ppcls preprocessor set gpu id * fix pybind * refine ppcls preprocessing use gpu logic * fdtensor device id is -1 by default * refine assert message Co-authored-by: heliqi <1101791222@qq.com>
This commit is contained in:
@@ -252,7 +252,8 @@ void FDTensor::FreeFn() {
|
||||
}
|
||||
}
|
||||
|
||||
void FDTensor::CopyBuffer(void* dst, const void* src, size_t nbytes) {
|
||||
void FDTensor::CopyBuffer(void* dst, const void* src, size_t nbytes,
|
||||
const Device& device, bool is_pinned_memory) {
|
||||
if (device == Device::GPU) {
|
||||
#ifdef WITH_GPU
|
||||
FDASSERT(cudaMemcpy(dst, src, nbytes, cudaMemcpyDeviceToDevice) == 0,
|
||||
@@ -295,7 +296,7 @@ FDTensor::FDTensor(const FDTensor& other)
|
||||
size_t nbytes = Nbytes();
|
||||
FDASSERT(ReallocFn(nbytes),
|
||||
"The FastDeploy FDTensor allocate memory error");
|
||||
CopyBuffer(buffer_, other.buffer_, nbytes);
|
||||
CopyBuffer(buffer_, other.buffer_, nbytes, device, is_pinned_memory);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -325,7 +326,7 @@ FDTensor& FDTensor::operator=(const FDTensor& other) {
|
||||
} else {
|
||||
Resize(other.shape, other.dtype, other.name, other.device);
|
||||
size_t nbytes = Nbytes();
|
||||
CopyBuffer(buffer_, other.buffer_, nbytes);
|
||||
CopyBuffer(buffer_, other.buffer_, nbytes, device, is_pinned_memory);
|
||||
}
|
||||
external_data_ptr = other.external_data_ptr;
|
||||
}
|
||||
|
@@ -39,6 +39,9 @@ struct FASTDEPLOY_DECL FDTensor {
|
||||
// GPU to inference the model
|
||||
// so we can skip data transfer, which may improve the efficience
|
||||
Device device = Device::CPU;
|
||||
// By default the device id of FDTensor is -1, which means this value is
|
||||
// invalid, and FDTensor is using the same device id as Runtime.
|
||||
int device_id = -1;
|
||||
|
||||
// Whether the data buffer is in pinned memory, which is allocated
|
||||
// with cudaMallocHost()
|
||||
@@ -130,8 +133,9 @@ struct FASTDEPLOY_DECL FDTensor {
|
||||
|
||||
~FDTensor() { FreeFn(); }
|
||||
|
||||
private:
|
||||
void CopyBuffer(void* dst, const void* src, size_t nbytes);
|
||||
static void CopyBuffer(void* dst, const void* src, size_t nbytes,
|
||||
const Device& device = Device::CPU,
|
||||
bool is_pinned_memory = false);
|
||||
};
|
||||
|
||||
} // namespace fastdeploy
|
||||
|
@@ -85,8 +85,9 @@ struct ConcatFunctor {
|
||||
int64_t col_len = input_cols[j];
|
||||
const T* input_data = reinterpret_cast<const T*>(input[j].Data());
|
||||
for (int64_t k = 0; k < out_rows; ++k) {
|
||||
std::memcpy(output_data + k * out_cols + col_idx,
|
||||
input_data + k * col_len, sizeof(T) * col_len);
|
||||
FDTensor::CopyBuffer(output_data + k * out_cols + col_idx,
|
||||
input_data + k * col_len, sizeof(T) * col_len,
|
||||
input[j].device, input[j].is_pinned_memory);
|
||||
}
|
||||
col_idx += col_len;
|
||||
}
|
||||
@@ -97,7 +98,8 @@ template <typename T>
|
||||
void ConcatKernel(const std::vector<FDTensor>& input, FDTensor* output,
|
||||
int axis) {
|
||||
auto output_shape = ComputeAndCheckConcatOutputShape(input, axis);
|
||||
output->Allocate(output_shape, TypeToDataType<T>::dtype);
|
||||
output->Resize(output_shape, TypeToDataType<T>::dtype, output->name,
|
||||
input[0].device);
|
||||
|
||||
ConcatFunctor<T> functor;
|
||||
functor(input, axis, output);
|
||||
@@ -115,10 +117,9 @@ void Concat(const std::vector<FDTensor>& x, FDTensor* out, int axis) {
|
||||
if (axis < 0) {
|
||||
axis += rank;
|
||||
}
|
||||
FDTensor out_temp;
|
||||
|
||||
FD_VISIT_ALL_TYPES(x[0].dtype, "Concat",
|
||||
([&] { ConcatKernel<data_t>(x, &out_temp, axis); }));
|
||||
*out = std::move(out_temp);
|
||||
([&] { ConcatKernel<data_t>(x, out, axis); }));
|
||||
}
|
||||
|
||||
} // namespace function
|
||||
|
@@ -568,6 +568,11 @@ std::vector<TensorInfo> Runtime::GetOutputInfos() {
|
||||
|
||||
bool Runtime::Infer(std::vector<FDTensor>& input_tensors,
|
||||
std::vector<FDTensor>* output_tensors) {
|
||||
for (auto& tensor: input_tensors) {
|
||||
FDASSERT(tensor.device_id < 0 || tensor.device_id == option.device_id,
|
||||
"Device id of input tensor(%d) and runtime(%d) are not same.",
|
||||
tensor.device_id, option.device_id);
|
||||
}
|
||||
return backend_->Infer(input_tensors, output_tensors);
|
||||
}
|
||||
|
||||
|
0
fastdeploy/vision/classification/ppcls/model.cc
Executable file → Normal file
0
fastdeploy/vision/classification/ppcls/model.cc
Executable file → Normal file
@@ -28,6 +28,9 @@ void BindPaddleClas(pybind11::module& m) {
|
||||
pybind11::eval("raise Exception('Failed to preprocess the input data in PaddleClasPreprocessor.')");
|
||||
}
|
||||
return outputs;
|
||||
})
|
||||
.def("use_gpu", [](vision::classification::PaddleClasPreprocessor& self, int gpu_id = -1) {
|
||||
self.UseGpu(gpu_id);
|
||||
});
|
||||
|
||||
pybind11::class_<vision::classification::PaddleClasPostprocessor>(
|
||||
|
@@ -15,17 +15,22 @@
|
||||
#include "fastdeploy/vision/classification/ppcls/preprocessor.h"
|
||||
#include "fastdeploy/function/concat.h"
|
||||
#include "yaml-cpp/yaml.h"
|
||||
#ifdef WITH_GPU
|
||||
#include <cuda_runtime_api.h>
|
||||
#endif
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace classification {
|
||||
|
||||
PaddleClasPreprocessor::PaddleClasPreprocessor(const std::string& config_file) {
|
||||
FDASSERT(BuildPreprocessPipelineFromConfig(config_file), "Failed to create PaddleClasPreprocessor.");
|
||||
FDASSERT(BuildPreprocessPipelineFromConfig(config_file),
|
||||
"Failed to create PaddleClasPreprocessor.");
|
||||
initialized_ = true;
|
||||
}
|
||||
|
||||
bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig(const std::string& config_file) {
|
||||
bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig(
|
||||
const std::string& config_file) {
|
||||
processors_.clear();
|
||||
YAML::Node cfg;
|
||||
try {
|
||||
@@ -73,6 +78,19 @@ bool PaddleClasPreprocessor::BuildPreprocessPipelineFromConfig(const std::string
|
||||
return true;
|
||||
}
|
||||
|
||||
void PaddleClasPreprocessor::UseGpu(int gpu_id) {
|
||||
#ifdef WITH_GPU
|
||||
use_cuda_ = true;
|
||||
if (gpu_id < 0) return;
|
||||
device_id_ = gpu_id;
|
||||
cudaSetDevice(device_id_);
|
||||
#else
|
||||
FDWARNING << "FastDeploy didn't compile with WITH_GPU. "
|
||||
<< "Will force to use CPU to run preprocessing." << std::endl;
|
||||
use_cuda_ = false;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool PaddleClasPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs) {
|
||||
if (!initialized_) {
|
||||
FDERROR << "The preprocessor is not initialized." << std::endl;
|
||||
@@ -85,8 +103,15 @@ bool PaddleClasPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTenso
|
||||
|
||||
for (size_t i = 0; i < images->size(); ++i) {
|
||||
for (size_t j = 0; j < processors_.size(); ++j) {
|
||||
if (!(*(processors_[j].get()))(&((*images)[i]))) {
|
||||
FDERROR << "Failed to processs image:" << i << " in " << processors_[i]->Name() << "." << std::endl;
|
||||
bool ret = false;
|
||||
if (processors_[j]->Name() == "NormalizeAndPermute" && use_cuda_) {
|
||||
ret = (*(processors_[j].get()))(&((*images)[i]), ProcLib::CUDA);
|
||||
} else {
|
||||
ret = (*(processors_[j].get()))(&((*images)[i]));
|
||||
}
|
||||
if (!ret) {
|
||||
FDERROR << "Failed to processs image:" << i << " in "
|
||||
<< processors_[i]->Name() << "." << std::endl;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -104,6 +129,7 @@ bool PaddleClasPreprocessor::Run(std::vector<FDMat>* images, std::vector<FDTenso
|
||||
} else {
|
||||
function::Concat(tensors, &((*outputs)[0]), 0);
|
||||
}
|
||||
(*outputs)[0].device_id = device_id_;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@@ -38,11 +38,19 @@ class FASTDEPLOY_DECL PaddleClasPreprocessor {
|
||||
*/
|
||||
bool Run(std::vector<FDMat>* images, std::vector<FDTensor>* outputs);
|
||||
|
||||
/** \brief Use GPU to run preprocessing
|
||||
*
|
||||
* \param[in] gpu_id GPU device id
|
||||
*/
|
||||
void UseGpu(int gpu_id = -1);
|
||||
|
||||
private:
|
||||
bool BuildPreprocessPipelineFromConfig(const std::string& config_file);
|
||||
std::vector<std::shared_ptr<Processor>> processors_;
|
||||
bool initialized_ = false;
|
||||
bool use_cuda_ = false;
|
||||
// GPU device id
|
||||
int device_id_ = -1;
|
||||
};
|
||||
|
||||
} // namespace classification
|
||||
|
@@ -30,12 +30,32 @@ bool Processor::operator()(Mat* mat, ProcLib lib) {
|
||||
return ImplByFlyCV(mat);
|
||||
#else
|
||||
FDASSERT(false, "FastDeploy didn't compile with FlyCV.");
|
||||
#endif
|
||||
} else if (target == ProcLib::CUDA) {
|
||||
#ifdef WITH_GPU
|
||||
return ImplByCuda(mat);
|
||||
#else
|
||||
FDASSERT(false, "FastDeploy didn't compile with WITH_GPU.");
|
||||
#endif
|
||||
}
|
||||
// DEFAULT & OPENCV
|
||||
return ImplByOpenCV(mat);
|
||||
}
|
||||
|
||||
FDTensor* Processor::UpdateAndGetReusedBuffer(
|
||||
const std::vector<int64_t>& new_shape, const int& opencv_dtype,
|
||||
const std::string& buffer_name, const Device& new_device,
|
||||
const bool& use_pinned_memory) {
|
||||
if (reused_buffers_.count(buffer_name) == 0) {
|
||||
reused_buffers_[buffer_name] = FDTensor();
|
||||
}
|
||||
reused_buffers_[buffer_name].is_pinned_memory = use_pinned_memory;
|
||||
reused_buffers_[buffer_name].Resize(new_shape,
|
||||
OpenCVDataTypeToFD(opencv_dtype),
|
||||
buffer_name, new_device);
|
||||
return &reused_buffers_[buffer_name];
|
||||
}
|
||||
|
||||
void EnableFlyCV() {
|
||||
#ifdef ENABLE_FLYCV
|
||||
DefaultProcLib::default_lib = ProcLib::FLYCV;
|
||||
|
@@ -18,6 +18,7 @@
|
||||
#include "fastdeploy/vision/common/processors/mat.h"
|
||||
#include "opencv2/highgui/highgui.hpp"
|
||||
#include "opencv2/imgproc/imgproc.hpp"
|
||||
#include <unordered_map>
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
@@ -55,7 +56,20 @@ class FASTDEPLOY_DECL Processor {
|
||||
return ImplByOpenCV(mat);
|
||||
}
|
||||
|
||||
virtual bool ImplByCuda(Mat* mat) {
|
||||
return ImplByOpenCV(mat);
|
||||
}
|
||||
|
||||
virtual bool operator()(Mat* mat, ProcLib lib = ProcLib::DEFAULT);
|
||||
|
||||
protected:
|
||||
FDTensor* UpdateAndGetReusedBuffer(
|
||||
const std::vector<int64_t>& new_shape, const int& opencv_dtype,
|
||||
const std::string& buffer_name, const Device& new_device = Device::CPU,
|
||||
const bool& use_pinned_memory = false);
|
||||
|
||||
private:
|
||||
std::unordered_map<std::string, FDTensor> reused_buffers_;
|
||||
};
|
||||
|
||||
} // namespace vision
|
||||
|
@@ -34,7 +34,7 @@ void* Mat::Data() {
|
||||
|
||||
void Mat::ShareWithTensor(FDTensor* tensor) {
|
||||
tensor->SetExternalData({Channels(), Height(), Width()}, Type(), Data());
|
||||
tensor->device = Device::CPU;
|
||||
tensor->device = device;
|
||||
if (layout == Layout::HWC) {
|
||||
tensor->shape = {Height(), Width(), Channels()};
|
||||
}
|
||||
|
@@ -131,6 +131,7 @@ struct FASTDEPLOY_DECL Mat {
|
||||
|
||||
ProcLib mat_type = ProcLib::OPENCV;
|
||||
Layout layout = Layout::HWC;
|
||||
Device device = Device::CPU;
|
||||
|
||||
// Create FD Mat from FD Tensor. This method only create a
|
||||
// new FD Mat with zero copy and it's data pointer is reference
|
||||
|
@@ -73,7 +73,6 @@ bool NormalizeAndPermute::ImplByOpenCV(Mat* mat) {
|
||||
res.ptr() + i * origin_h * origin_w * 4),
|
||||
0);
|
||||
}
|
||||
|
||||
mat->SetMat(res);
|
||||
mat->layout = Layout::CHW;
|
||||
return true;
|
||||
|
82
fastdeploy/vision/common/processors/normalize_and_permute.cu
Normal file
82
fastdeploy/vision/common/processors/normalize_and_permute.cu
Normal file
@@ -0,0 +1,82 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/common/processors/normalize_and_permute.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
|
||||
__global__ void NormalizeAndPermuteKernel(
|
||||
uint8_t* src, float* dst, const float* alpha, const float* beta,
|
||||
int num_channel, bool swap_rb, int edge) {
|
||||
int idx = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
if (idx >= edge) return;
|
||||
|
||||
if (swap_rb) {
|
||||
uint8_t tmp = src[num_channel * idx];
|
||||
src[num_channel * idx] = src[num_channel * idx + 2];
|
||||
src[num_channel * idx + 2] = tmp;
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_channel; ++i) {
|
||||
dst[idx + edge * i] = src[num_channel * idx + i] * alpha[i] + beta[i];
|
||||
}
|
||||
}
|
||||
|
||||
bool NormalizeAndPermute::ImplByCuda(Mat* mat) {
|
||||
cv::Mat* im = mat->GetOpenCVMat();
|
||||
std::string buf_name = Name() + "_src";
|
||||
std::vector<int64_t> shape = {im->rows, im->cols, im->channels()};
|
||||
FDTensor* src = UpdateAndGetReusedBuffer(shape, im->type(), buf_name,
|
||||
Device::GPU);
|
||||
FDASSERT(cudaMemcpy(src->Data(), im->ptr(), src->Nbytes(),
|
||||
cudaMemcpyHostToDevice) == 0,
|
||||
"Error occurs while copy memory from CPU to GPU.");
|
||||
|
||||
buf_name = Name() + "_dst";
|
||||
FDTensor* dst = UpdateAndGetReusedBuffer(shape, CV_32FC(im->channels()),
|
||||
buf_name, Device::GPU);
|
||||
cv::Mat res(im->rows, im->cols, CV_32FC(im->channels()), dst->Data());
|
||||
|
||||
buf_name = Name() + "_alpha";
|
||||
FDTensor* alpha = UpdateAndGetReusedBuffer({(int64_t)alpha_.size()}, CV_32FC1,
|
||||
buf_name, Device::GPU);
|
||||
FDASSERT(cudaMemcpy(alpha->Data(), alpha_.data(), alpha->Nbytes(),
|
||||
cudaMemcpyHostToDevice) == 0,
|
||||
"Error occurs while copy memory from CPU to GPU.");
|
||||
|
||||
buf_name = Name() + "_beta";
|
||||
FDTensor* beta = UpdateAndGetReusedBuffer({(int64_t)beta_.size()}, CV_32FC1,
|
||||
buf_name, Device::GPU);
|
||||
FDASSERT(cudaMemcpy(beta->Data(), beta_.data(), beta->Nbytes(),
|
||||
cudaMemcpyHostToDevice) == 0,
|
||||
"Error occurs while copy memory from CPU to GPU.");
|
||||
|
||||
int jobs = im->cols * im->rows;
|
||||
int threads = 256;
|
||||
int blocks = ceil(jobs / (float)threads);
|
||||
NormalizeAndPermuteKernel<<<blocks, threads, 0, NULL>>>(
|
||||
reinterpret_cast<uint8_t*>(src->Data()),
|
||||
reinterpret_cast<float*>(dst->Data()),
|
||||
reinterpret_cast<float*>(alpha->Data()),
|
||||
reinterpret_cast<float*>(beta->Data()), im->channels(), swap_rb_, jobs);
|
||||
|
||||
mat->SetMat(res);
|
||||
mat->device = Device::GPU;
|
||||
mat->layout = Layout::CHW;
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
@@ -28,6 +28,9 @@ class FASTDEPLOY_DECL NormalizeAndPermute : public Processor {
|
||||
bool ImplByOpenCV(Mat* mat);
|
||||
#ifdef ENABLE_FLYCV
|
||||
bool ImplByFlyCV(Mat* mat);
|
||||
#endif
|
||||
#ifdef WITH_GPU
|
||||
bool ImplByCuda(Mat* mat);
|
||||
#endif
|
||||
std::string Name() { return "NormalizeAndPermute"; }
|
||||
|
||||
|
@@ -30,6 +30,9 @@ std::ostream& operator<<(std::ostream& out, const ProcLib& p) {
|
||||
case ProcLib::FLYCV:
|
||||
out << "ProcLib::FLYCV";
|
||||
break;
|
||||
case ProcLib::CUDA:
|
||||
out << "ProcLib::CUDA";
|
||||
break;
|
||||
default:
|
||||
FDASSERT(false, "Unknow type of ProcLib.");
|
||||
}
|
||||
|
@@ -18,7 +18,7 @@
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
|
||||
enum class FASTDEPLOY_DECL ProcLib { DEFAULT, OPENCV, FLYCV };
|
||||
enum class FASTDEPLOY_DECL ProcLib { DEFAULT, OPENCV, FLYCV, CUDA };
|
||||
|
||||
FASTDEPLOY_DECL std::ostream& operator<<(std::ostream& out, const ProcLib& p);
|
||||
|
||||
|
@@ -146,7 +146,6 @@ bool PPYOLOE::BuildPreprocessPipelineFromConfig() {
|
||||
|
||||
// Fusion will improve performance
|
||||
FuseTransforms(&processors_);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -161,8 +160,6 @@ bool PPYOLOE::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
|
||||
}
|
||||
}
|
||||
|
||||
Cast::Run(mat, "float");
|
||||
|
||||
outputs->resize(2);
|
||||
(*outputs)[0].name = InputInfoOfRuntime(0).name;
|
||||
mat->ShareWithTensor(&((*outputs)[0]));
|
||||
|
@@ -39,6 +39,8 @@ PaddleSegModel::PaddleSegModel(const std::string& model_file,
|
||||
}
|
||||
|
||||
bool PaddleSegModel::Initialize() {
|
||||
reused_input_tensors_.resize(1);
|
||||
reused_output_tensors_.resize(1);
|
||||
if (!BuildPreprocessPipelineFromConfig()) {
|
||||
FDERROR << "Failed to build preprocess pipeline from configuration file."
|
||||
<< std::endl;
|
||||
@@ -133,6 +135,9 @@ bool PaddleSegModel::BuildPreprocessPipelineFromConfig() {
|
||||
if (!(this->disable_normalize_and_permute)) {
|
||||
processors_.push_back(std::make_shared<HWC2CHW>());
|
||||
}
|
||||
|
||||
// Fusion will improve performance
|
||||
FuseTransforms(&processors_);
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -332,7 +337,6 @@ bool PaddleSegModel::Postprocess(
|
||||
|
||||
bool PaddleSegModel::Predict(cv::Mat* im, SegmentationResult* result) {
|
||||
Mat mat(*im);
|
||||
std::vector<FDTensor> processed_data(1);
|
||||
|
||||
std::map<std::string, std::array<int, 2>> im_info;
|
||||
|
||||
@@ -340,18 +344,18 @@ bool PaddleSegModel::Predict(cv::Mat* im, SegmentationResult* result) {
|
||||
im_info["input_shape"] = {static_cast<int>(mat.Height()),
|
||||
static_cast<int>(mat.Width())};
|
||||
|
||||
if (!Preprocess(&mat, &(processed_data[0]))) {
|
||||
if (!Preprocess(&mat, &(reused_input_tensors_[0]))) {
|
||||
FDERROR << "Failed to preprocess input data while using model:"
|
||||
<< ModelName() << "." << std::endl;
|
||||
return false;
|
||||
}
|
||||
std::vector<FDTensor> infer_result(1);
|
||||
if (!Infer(processed_data, &infer_result)) {
|
||||
|
||||
if (!Infer()) {
|
||||
FDERROR << "Failed to inference while using model:" << ModelName() << "."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
if (!Postprocess(&infer_result[0], result, im_info)) {
|
||||
if (!Postprocess(&reused_output_tensors_[0], result, im_info)) {
|
||||
FDERROR << "Failed to postprocess while using model:" << ModelName() << "."
|
||||
<< std::endl;
|
||||
return false;
|
||||
|
@@ -35,6 +35,13 @@ class PaddleClasPreprocessor:
|
||||
"""
|
||||
return self._preprocessor.run(input_ims)
|
||||
|
||||
def use_gpu(self, gpu_id=-1):
|
||||
"""Use CUDA preprocessors
|
||||
|
||||
:param: gpu_id: GPU device id
|
||||
"""
|
||||
return self._preprocessor.use_gpu(gpu_id)
|
||||
|
||||
|
||||
class PaddleClasPostprocessor:
|
||||
def __init__(self, topk=1):
|
||||
|
Reference in New Issue
Block a user