mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[nvJPEG] Integrate nvJPEG decoder (#1288)
* nvjpeg cmake * add common decoder, nvjpeg decoder and add image name predict api * ppclas support nvjpeg decoder * remove useless comments * image decoder support opencv * nvjpeg decode fallback to opencv * fdtensor add nbytes_allocated * single image decode api * fix bug * add pybind * ignore nvjpeg on jetson * fix cmake in * predict on fdmat * remove image names predict api, add image decoder tutorial * Update __init__.py * fix pybind
This commit is contained in:
@@ -300,10 +300,16 @@ if(WITH_GPU)
|
||||
include_directories(${CUDA_DIRECTORY}/include)
|
||||
if(WIN32)
|
||||
find_library(CUDA_LIB cudart ${CUDA_DIRECTORY}/lib/x64)
|
||||
find_library(NVJPEG_LIB nvjpeg ${CUDA_DIRECTORY}/lib/x64)
|
||||
add_definitions(-DENABLE_NVJPEG)
|
||||
else()
|
||||
find_library(CUDA_LIB cudart ${CUDA_DIRECTORY}/lib64)
|
||||
if(NOT BUILD_ON_JETSON)
|
||||
find_library(NVJPEG_LIB nvjpeg ${CUDA_DIRECTORY}/lib64)
|
||||
add_definitions(-DENABLE_NVJPEG)
|
||||
endif()
|
||||
list(APPEND DEPEND_LIBS ${CUDA_LIB})
|
||||
endif()
|
||||
list(APPEND DEPEND_LIBS ${CUDA_LIB} ${NVJPEG_LIB})
|
||||
|
||||
# build CUDA source files in fastdeploy, CUDA source files include CUDA preprocessing, TRT plugins, etc.
|
||||
enable_language(CUDA)
|
||||
|
@@ -169,21 +169,25 @@ if(ENABLE_POROS_BACKEND)
|
||||
endif()
|
||||
|
||||
if(WITH_GPU)
|
||||
if (NOT CUDA_DIRECTORY)
|
||||
if(NOT CUDA_DIRECTORY)
|
||||
set(CUDA_DIRECTORY "/usr/local/cuda")
|
||||
endif()
|
||||
if(WIN32)
|
||||
find_library(CUDA_LIB cudart ${CUDA_DIRECTORY}/lib/x64)
|
||||
find_library(NVJPEG_LIB nvjpeg ${CUDA_DIRECTORY}/lib/x64)
|
||||
else()
|
||||
find_library(CUDA_LIB cudart ${CUDA_DIRECTORY}/lib64)
|
||||
if(NOT BUILD_ON_JETSON)
|
||||
find_library(NVJPEG_LIB nvjpeg ${CUDA_DIRECTORY}/lib64)
|
||||
endif()
|
||||
endif()
|
||||
if(NOT CUDA_LIB)
|
||||
message(FATAL_ERROR "[FastDeploy] Cannot find library cudart in ${CUDA_DIRECTORY}, Please define CUDA_DIRECTORY, e.g -DCUDA_DIRECTORY=/path/to/cuda")
|
||||
endif()
|
||||
list(APPEND FASTDEPLOY_LIBS ${CUDA_LIB})
|
||||
list(APPEND FASTDEPLOY_LIBS ${CUDA_LIB} ${NVJPEG_LIB})
|
||||
list(APPEND FASTDEPLOY_INCS ${CUDA_DIRECTORY}/include)
|
||||
|
||||
if (ENABLE_TRT_BACKEND)
|
||||
if(ENABLE_TRT_BACKEND)
|
||||
if(BUILD_ON_JETSON)
|
||||
find_library(TRT_INFER_LIB nvinfer /usr/lib/aarch64-linux-gnu/)
|
||||
find_library(TRT_ONNX_LIB nvonnxparser /usr/lib/aarch64-linux-gnu/)
|
||||
|
@@ -245,12 +245,13 @@ void FDTensor::PrintInfo(const std::string& prefix) const {
|
||||
bool FDTensor::ReallocFn(size_t nbytes) {
|
||||
if (device == Device::GPU) {
|
||||
#ifdef WITH_GPU
|
||||
size_t original_nbytes = Nbytes();
|
||||
size_t original_nbytes = nbytes_allocated;
|
||||
if (nbytes > original_nbytes) {
|
||||
if (buffer_ != nullptr) {
|
||||
FDDeviceFree()(buffer_);
|
||||
}
|
||||
FDDeviceAllocator()(&buffer_, nbytes);
|
||||
nbytes_allocated = nbytes;
|
||||
}
|
||||
return buffer_ != nullptr;
|
||||
#else
|
||||
@@ -262,12 +263,13 @@ bool FDTensor::ReallocFn(size_t nbytes) {
|
||||
} else {
|
||||
if (is_pinned_memory) {
|
||||
#ifdef WITH_GPU
|
||||
size_t original_nbytes = Nbytes();
|
||||
size_t original_nbytes = nbytes_allocated;
|
||||
if (nbytes > original_nbytes) {
|
||||
if (buffer_ != nullptr) {
|
||||
FDDeviceHostFree()(buffer_);
|
||||
}
|
||||
FDDeviceHostAllocator()(&buffer_, nbytes);
|
||||
nbytes_allocated = nbytes;
|
||||
}
|
||||
return buffer_ != nullptr;
|
||||
#else
|
||||
@@ -278,6 +280,7 @@ bool FDTensor::ReallocFn(size_t nbytes) {
|
||||
#endif
|
||||
}
|
||||
buffer_ = realloc(buffer_, nbytes);
|
||||
nbytes_allocated = nbytes;
|
||||
return buffer_ != nullptr;
|
||||
}
|
||||
}
|
||||
@@ -299,6 +302,7 @@ void FDTensor::FreeFn() {
|
||||
}
|
||||
}
|
||||
buffer_ = nullptr;
|
||||
nbytes_allocated = 0;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -380,7 +384,7 @@ FDTensor::FDTensor(const FDTensor& other)
|
||||
device_id(other.device_id) {
|
||||
// Copy buffer
|
||||
if (other.buffer_ == nullptr) {
|
||||
buffer_ = nullptr;
|
||||
FreeFn();
|
||||
} else {
|
||||
size_t nbytes = Nbytes();
|
||||
FDASSERT(ReallocFn(nbytes),
|
||||
@@ -396,7 +400,8 @@ FDTensor::FDTensor(FDTensor&& other)
|
||||
dtype(other.dtype),
|
||||
external_data_ptr(other.external_data_ptr),
|
||||
device(other.device),
|
||||
device_id(other.device_id) {
|
||||
device_id(other.device_id),
|
||||
nbytes_allocated(other.nbytes_allocated) {
|
||||
other.name = "";
|
||||
// Note(zhoushunjie): Avoid double free.
|
||||
other.buffer_ = nullptr;
|
||||
@@ -435,6 +440,7 @@ FDTensor& FDTensor::operator=(FDTensor&& other) {
|
||||
dtype = other.dtype;
|
||||
device = other.device;
|
||||
device_id = other.device_id;
|
||||
nbytes_allocated = other.nbytes_allocated;
|
||||
|
||||
other.name = "";
|
||||
// Note(zhoushunjie): Avoid double free.
|
||||
|
@@ -54,6 +54,11 @@ struct FASTDEPLOY_DECL FDTensor {
|
||||
// other devices' data
|
||||
std::vector<int8_t> temporary_cpu_buffer;
|
||||
|
||||
// The number of bytes allocated so far.
|
||||
// When resizing GPU memory, we will free and realloc the memory only if the
|
||||
// required size is larger than this value.
|
||||
size_t nbytes_allocated = 0;
|
||||
|
||||
// Get data buffer pointer
|
||||
void* MutableData();
|
||||
|
||||
|
36
fastdeploy/vision/classification/ppcls/model.cc
Executable file → Normal file
36
fastdeploy/vision/classification/ppcls/model.cc
Executable file → Normal file
@@ -13,6 +13,7 @@
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/classification/ppcls/model.h"
|
||||
|
||||
#include "fastdeploy/utils/unique_ptr.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
@@ -23,7 +24,8 @@ PaddleClasModel::PaddleClasModel(const std::string& model_file,
|
||||
const std::string& params_file,
|
||||
const std::string& config_file,
|
||||
const RuntimeOption& custom_option,
|
||||
const ModelFormat& model_format) : preprocessor_(config_file) {
|
||||
const ModelFormat& model_format)
|
||||
: preprocessor_(config_file) {
|
||||
if (model_format == ModelFormat::PADDLE) {
|
||||
valid_cpu_backends = {Backend::OPENVINO, Backend::PDINFER, Backend::ORT,
|
||||
Backend::LITE};
|
||||
@@ -32,10 +34,9 @@ PaddleClasModel::PaddleClasModel(const std::string& model_file,
|
||||
valid_ascend_backends = {Backend::LITE};
|
||||
valid_kunlunxin_backends = {Backend::LITE};
|
||||
valid_ipu_backends = {Backend::PDINFER};
|
||||
}else if (model_format == ModelFormat::SOPHGO) {
|
||||
} else if (model_format == ModelFormat::SOPHGO) {
|
||||
valid_sophgonpu_backends = {Backend::SOPHGOTPU};
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
valid_cpu_backends = {Backend::ORT, Backend::OPENVINO};
|
||||
valid_gpu_backends = {Backend::ORT, Backend::TRT};
|
||||
valid_rknpu_backends = {Backend::RKNPU2};
|
||||
@@ -49,7 +50,8 @@ PaddleClasModel::PaddleClasModel(const std::string& model_file,
|
||||
}
|
||||
|
||||
std::unique_ptr<PaddleClasModel> PaddleClasModel::Clone() const {
|
||||
std::unique_ptr<PaddleClasModel> clone_model = utils::make_unique<PaddleClasModel>(PaddleClasModel(*this));
|
||||
std::unique_ptr<PaddleClasModel> clone_model =
|
||||
utils::make_unique<PaddleClasModel>(PaddleClasModel(*this));
|
||||
clone_model->SetRuntime(clone_model->CloneRuntime());
|
||||
return clone_model;
|
||||
}
|
||||
@@ -71,17 +73,30 @@ bool PaddleClasModel::Predict(cv::Mat* im, ClassifyResult* result, int topk) {
|
||||
}
|
||||
|
||||
bool PaddleClasModel::Predict(const cv::Mat& im, ClassifyResult* result) {
|
||||
FDMat mat = WrapMat(im);
|
||||
return Predict(mat, result);
|
||||
}
|
||||
|
||||
bool PaddleClasModel::BatchPredict(const std::vector<cv::Mat>& images,
|
||||
std::vector<ClassifyResult>* results) {
|
||||
std::vector<FDMat> mats = WrapMat(images);
|
||||
return BatchPredict(mats, results);
|
||||
}
|
||||
|
||||
bool PaddleClasModel::Predict(const FDMat& mat, ClassifyResult* result) {
|
||||
std::vector<ClassifyResult> results;
|
||||
if (!BatchPredict({im}, &results)) {
|
||||
std::vector<FDMat> mats = {mat};
|
||||
if (!BatchPredict(mats, &results)) {
|
||||
return false;
|
||||
}
|
||||
*result = std::move(results[0]);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool PaddleClasModel::BatchPredict(const std::vector<cv::Mat>& images, std::vector<ClassifyResult>* results) {
|
||||
std::vector<FDMat> fd_images = WrapMat(images);
|
||||
if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
|
||||
bool PaddleClasModel::BatchPredict(const std::vector<FDMat>& mats,
|
||||
std::vector<ClassifyResult>* results) {
|
||||
std::vector<FDMat> fd_mats = mats;
|
||||
if (!preprocessor_.Run(&fd_mats, &reused_input_tensors_)) {
|
||||
FDERROR << "Failed to preprocess the input image." << std::endl;
|
||||
return false;
|
||||
}
|
||||
@@ -92,7 +107,8 @@ bool PaddleClasModel::BatchPredict(const std::vector<cv::Mat>& images, std::vect
|
||||
}
|
||||
|
||||
if (!postprocessor_.Run(reused_output_tensors_, results)) {
|
||||
FDERROR << "Failed to postprocess the inference results by runtime." << std::endl;
|
||||
FDERROR << "Failed to postprocess the inference results by runtime."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@@ -75,6 +75,23 @@ class FASTDEPLOY_DECL PaddleClasModel : public FastDeployModel {
|
||||
virtual bool BatchPredict(const std::vector<cv::Mat>& imgs,
|
||||
std::vector<ClassifyResult>* results);
|
||||
|
||||
/** \brief Predict the classification result for an input image
|
||||
*
|
||||
* \param[in] mat The input mat
|
||||
* \param[in] result The output classification result
|
||||
* \return true if the prediction successed, otherwise false
|
||||
*/
|
||||
virtual bool Predict(const FDMat& mat, ClassifyResult* result);
|
||||
|
||||
/** \brief Predict the classification results for a batch of input images
|
||||
*
|
||||
* \param[in] mats, The input mat list
|
||||
* \param[in] results The output classification result list
|
||||
* \return true if the prediction successed, otherwise false
|
||||
*/
|
||||
virtual bool BatchPredict(const std::vector<FDMat>& mats,
|
||||
std::vector<ClassifyResult>* results);
|
||||
|
||||
/// Get preprocessor reference of PaddleClasModel
|
||||
virtual PaddleClasPreprocessor& GetPreprocessor() {
|
||||
return preprocessor_;
|
||||
|
112
fastdeploy/vision/common/image_decoder/image_decoder.cc
Normal file
112
fastdeploy/vision/common/image_decoder/image_decoder.cc
Normal file
@@ -0,0 +1,112 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "fastdeploy/vision/common/image_decoder/image_decoder.h"
|
||||
|
||||
#include "opencv2/imgcodecs.hpp"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
|
||||
ImageDecoder::ImageDecoder(ImageDecoderLib lib) {
|
||||
if (lib == ImageDecoderLib::NVJPEG) {
|
||||
#ifdef ENABLE_NVJPEG
|
||||
nvjpeg::init_decoder(nvjpeg_params_);
|
||||
#endif
|
||||
}
|
||||
lib_ = lib;
|
||||
}
|
||||
|
||||
ImageDecoder::~ImageDecoder() {
|
||||
if (lib_ == ImageDecoderLib::NVJPEG) {
|
||||
#ifdef ENABLE_NVJPEG
|
||||
nvjpeg::destroy_decoder(nvjpeg_params_);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
bool ImageDecoder::Decode(const std::string& img_name, FDMat* mat) {
|
||||
std::vector<FDMat> mats(1);
|
||||
mats[0] = std::move(*mat);
|
||||
if (!BatchDecode({img_name}, &mats)) {
|
||||
return false;
|
||||
}
|
||||
*mat = std::move(mats[0]);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ImageDecoder::BatchDecode(const std::vector<std::string>& img_names,
|
||||
std::vector<FDMat>* mats) {
|
||||
if (lib_ == ImageDecoderLib::OPENCV) {
|
||||
return ImplByOpenCV(img_names, mats);
|
||||
} else if (lib_ == ImageDecoderLib::NVJPEG) {
|
||||
return ImplByNvJpeg(img_names, mats);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ImageDecoder::ImplByOpenCV(const std::vector<std::string>& img_names,
|
||||
std::vector<FDMat>* mats) {
|
||||
for (size_t i = 0; i < img_names.size(); ++i) {
|
||||
cv::Mat im = cv::imread(img_names[i]);
|
||||
(*mats)[i].SetMat(im);
|
||||
(*mats)[i].layout = Layout::HWC;
|
||||
(*mats)[i].SetWidth(im.cols);
|
||||
(*mats)[i].SetHeight(im.rows);
|
||||
(*mats)[i].SetChannels(im.channels());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ImageDecoder::ImplByNvJpeg(const std::vector<std::string>& img_names,
|
||||
std::vector<FDMat>* mats) {
|
||||
#ifdef ENABLE_NVJPEG
|
||||
nvjpeg_params_.batch_size = img_names.size();
|
||||
std::vector<nvjpegImage_t> output_imgs(nvjpeg_params_.batch_size);
|
||||
std::vector<int> widths(nvjpeg_params_.batch_size);
|
||||
std::vector<int> heights(nvjpeg_params_.batch_size);
|
||||
// TODO(wangxinyu): support other output format
|
||||
nvjpeg_params_.fmt = NVJPEG_OUTPUT_BGRI;
|
||||
double total;
|
||||
nvjpeg_params_.stream = (*mats)[0].Stream();
|
||||
|
||||
std::vector<FDTensor*> output_buffers;
|
||||
for (size_t i = 0; i < mats->size(); ++i) {
|
||||
FDASSERT((*mats)[i].output_cache != nullptr,
|
||||
"The output_cache of FDMat was not set.");
|
||||
output_buffers.push_back((*mats)[i].output_cache);
|
||||
}
|
||||
|
||||
if (nvjpeg::process_images(img_names, nvjpeg_params_, total, output_imgs,
|
||||
output_buffers, widths, heights)) {
|
||||
// If nvJPEG decode failed, will fallback to OpenCV,
|
||||
// e.g. png format is not supported by nvJPEG
|
||||
FDWARNING << "nvJPEG decode failed, falling back to OpenCV for this batch"
|
||||
<< std::endl;
|
||||
return ImplByOpenCV(img_names, mats);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < mats->size(); ++i) {
|
||||
(*mats)[i].mat_type = ProcLib::CUDA;
|
||||
(*mats)[i].layout = Layout::HWC;
|
||||
(*mats)[i].SetTensor(output_buffers[i]);
|
||||
}
|
||||
#else
|
||||
FDASSERT(false, "FastDeploy didn't compile with NVJPEG.");
|
||||
#endif
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
49
fastdeploy/vision/common/image_decoder/image_decoder.h
Normal file
49
fastdeploy/vision/common/image_decoder/image_decoder.h
Normal file
@@ -0,0 +1,49 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "fastdeploy/utils/utils.h"
|
||||
#include "fastdeploy/vision/common/processors/mat.h"
|
||||
#include "fastdeploy/vision/common/image_decoder/nvjpeg_decoder.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
|
||||
enum class FASTDEPLOY_DECL ImageDecoderLib { OPENCV, NVJPEG };
|
||||
|
||||
class FASTDEPLOY_DECL ImageDecoder {
|
||||
public:
|
||||
explicit ImageDecoder(ImageDecoderLib lib = ImageDecoderLib::OPENCV);
|
||||
|
||||
~ImageDecoder();
|
||||
|
||||
bool Decode(const std::string& img_name, FDMat* mat);
|
||||
|
||||
bool BatchDecode(const std::vector<std::string>& img_names,
|
||||
std::vector<FDMat>* mats);
|
||||
|
||||
private:
|
||||
bool ImplByOpenCV(const std::vector<std::string>& img_names,
|
||||
std::vector<FDMat>* mats);
|
||||
bool ImplByNvJpeg(const std::vector<std::string>& img_names,
|
||||
std::vector<FDMat>* mats);
|
||||
ImageDecoderLib lib_ = ImageDecoderLib::OPENCV;
|
||||
#ifdef ENABLE_NVJPEG
|
||||
nvjpeg::decode_params_t nvjpeg_params_;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
363
fastdeploy/vision/common/image_decoder/nvjpeg_decoder.cc
Normal file
363
fastdeploy/vision/common/image_decoder/nvjpeg_decoder.cc
Normal file
@@ -0,0 +1,363 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Part of the following code in this file refs to
|
||||
// https://github.com/CVCUDA/CV-CUDA/blob/release_v0.2.x/samples/common/NvDecoder.cpp
|
||||
//
|
||||
// Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
// Licensed under the Apache-2.0 license
|
||||
// \brief
|
||||
// \author NVIDIA
|
||||
|
||||
#ifdef ENABLE_NVJPEG
|
||||
#include "fastdeploy/vision/common/image_decoder/nvjpeg_decoder.h"
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace nvjpeg {
|
||||
|
||||
#define CHECK_CUDA(call) \
|
||||
{ \
|
||||
cudaError_t _e = (call); \
|
||||
if (_e != cudaSuccess) { \
|
||||
std::cout << "CUDA Runtime failure: '#" << _e << "' at " << __FILE__ \
|
||||
<< ":" << __LINE__ << std::endl; \
|
||||
exit(1); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define CHECK_NVJPEG(call) \
|
||||
{ \
|
||||
nvjpegStatus_t _e = (call); \
|
||||
if (_e != NVJPEG_STATUS_SUCCESS) { \
|
||||
std::cout << "NVJPEG failure: '#" << _e << "' at " << __FILE__ << ":" \
|
||||
<< __LINE__ << std::endl; \
|
||||
exit(1); \
|
||||
} \
|
||||
}
|
||||
|
||||
static int dev_malloc(void** p, size_t s) { return (int)cudaMalloc(p, s); }
|
||||
|
||||
static int dev_free(void* p) { return (int)cudaFree(p); }
|
||||
|
||||
static int host_malloc(void** p, size_t s, unsigned int f) {
|
||||
return (int)cudaHostAlloc(p, s, f);
|
||||
}
|
||||
|
||||
static int host_free(void* p) { return (int)cudaFreeHost(p); }
|
||||
|
||||
static int read_images(const FileNames& image_names, FileData& raw_data,
|
||||
std::vector<size_t>& raw_len) {
|
||||
for (size_t i = 0; i < image_names.size(); ++i) {
|
||||
if (image_names.size() == 0) {
|
||||
std::cerr << "No valid images left in the input list, exit" << std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
// Read an image from disk.
|
||||
std::ifstream input(image_names[i].c_str(),
|
||||
std::ios::in | std::ios::binary | std::ios::ate);
|
||||
if (!(input.is_open())) {
|
||||
std::cerr << "Cannot open image: " << image_names[i] << std::endl;
|
||||
FDASSERT(false, "Read file error.");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get the size
|
||||
long unsigned int file_size = input.tellg();
|
||||
input.seekg(0, std::ios::beg);
|
||||
// resize if buffer is too small
|
||||
if (raw_data[i].size() < file_size) {
|
||||
raw_data[i].resize(file_size);
|
||||
}
|
||||
if (!input.read(raw_data[i].data(), file_size)) {
|
||||
std::cerr << "Cannot read from file: " << image_names[i] << std::endl;
|
||||
// image_names.erase(cur_iter);
|
||||
FDASSERT(false, "Read file error.");
|
||||
continue;
|
||||
}
|
||||
raw_len[i] = file_size;
|
||||
}
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
// prepare buffers for RGBi output format
|
||||
static int prepare_buffers(FileData& file_data, std::vector<size_t>& file_len,
|
||||
std::vector<int>& img_width,
|
||||
std::vector<int>& img_height,
|
||||
std::vector<nvjpegImage_t>& ibuf,
|
||||
std::vector<nvjpegImage_t>& isz,
|
||||
std::vector<FDTensor*>& output_buffers,
|
||||
const FileNames& current_names,
|
||||
decode_params_t& params) {
|
||||
int widths[NVJPEG_MAX_COMPONENT];
|
||||
int heights[NVJPEG_MAX_COMPONENT];
|
||||
int channels;
|
||||
nvjpegChromaSubsampling_t subsampling;
|
||||
|
||||
for (long unsigned int i = 0; i < file_data.size(); i++) {
|
||||
nvjpegStatus_t status = nvjpegGetImageInfo(
|
||||
params.nvjpeg_handle, (unsigned char*)file_data[i].data(), file_len[i],
|
||||
&channels, &subsampling, widths, heights);
|
||||
if (status != NVJPEG_STATUS_SUCCESS) {
|
||||
std::cout << "NVJPEG failure: #" << status << " in nvjpegGetImageInfo."
|
||||
<< std::endl;
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
img_width[i] = widths[0];
|
||||
img_height[i] = heights[0];
|
||||
|
||||
int mul = 1;
|
||||
// in the case of interleaved RGB output, write only to single channel, but
|
||||
// 3 samples at once
|
||||
if (params.fmt == NVJPEG_OUTPUT_RGBI || params.fmt == NVJPEG_OUTPUT_BGRI) {
|
||||
channels = 1;
|
||||
mul = 3;
|
||||
} else if (params.fmt == NVJPEG_OUTPUT_RGB ||
|
||||
params.fmt == NVJPEG_OUTPUT_BGR) {
|
||||
// in the case of rgb create 3 buffers with sizes of original image
|
||||
channels = 3;
|
||||
widths[1] = widths[2] = widths[0];
|
||||
heights[1] = heights[2] = heights[0];
|
||||
} else {
|
||||
FDASSERT(false, "Unsupport NVJPEG output format: %d", params.fmt);
|
||||
}
|
||||
|
||||
output_buffers[i]->Resize({heights[0], widths[0], mul * channels},
|
||||
FDDataType::UINT8, "output_cache", Device::GPU);
|
||||
|
||||
uint8_t* cur_buffer = reinterpret_cast<uint8_t*>(output_buffers[i]->Data());
|
||||
|
||||
// realloc output buffer if required
|
||||
for (int c = 0; c < channels; c++) {
|
||||
int aw = mul * widths[c];
|
||||
int ah = heights[c];
|
||||
size_t sz = aw * ah;
|
||||
ibuf[i].pitch[c] = aw;
|
||||
if (sz > isz[i].pitch[c]) {
|
||||
ibuf[i].channel[c] = cur_buffer;
|
||||
cur_buffer = cur_buffer + sz;
|
||||
isz[i].pitch[c] = sz;
|
||||
}
|
||||
}
|
||||
}
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
static void create_decoupled_api_handles(decode_params_t& params) {
|
||||
CHECK_NVJPEG(nvjpegDecoderCreate(params.nvjpeg_handle, NVJPEG_BACKEND_DEFAULT,
|
||||
¶ms.nvjpeg_decoder));
|
||||
CHECK_NVJPEG(nvjpegDecoderStateCreate(params.nvjpeg_handle,
|
||||
params.nvjpeg_decoder,
|
||||
¶ms.nvjpeg_decoupled_state));
|
||||
|
||||
CHECK_NVJPEG(nvjpegBufferPinnedCreate(params.nvjpeg_handle, NULL,
|
||||
¶ms.pinned_buffers[0]));
|
||||
CHECK_NVJPEG(nvjpegBufferPinnedCreate(params.nvjpeg_handle, NULL,
|
||||
¶ms.pinned_buffers[1]));
|
||||
CHECK_NVJPEG(nvjpegBufferDeviceCreate(params.nvjpeg_handle, NULL,
|
||||
¶ms.device_buffer));
|
||||
|
||||
CHECK_NVJPEG(
|
||||
nvjpegJpegStreamCreate(params.nvjpeg_handle, ¶ms.jpeg_streams[0]));
|
||||
CHECK_NVJPEG(
|
||||
nvjpegJpegStreamCreate(params.nvjpeg_handle, ¶ms.jpeg_streams[1]));
|
||||
|
||||
CHECK_NVJPEG(nvjpegDecodeParamsCreate(params.nvjpeg_handle,
|
||||
¶ms.nvjpeg_decode_params));
|
||||
}
|
||||
|
||||
static void destroy_decoupled_api_handles(decode_params_t& params) {
|
||||
CHECK_NVJPEG(nvjpegDecodeParamsDestroy(params.nvjpeg_decode_params));
|
||||
CHECK_NVJPEG(nvjpegJpegStreamDestroy(params.jpeg_streams[0]));
|
||||
CHECK_NVJPEG(nvjpegJpegStreamDestroy(params.jpeg_streams[1]));
|
||||
CHECK_NVJPEG(nvjpegBufferPinnedDestroy(params.pinned_buffers[0]));
|
||||
CHECK_NVJPEG(nvjpegBufferPinnedDestroy(params.pinned_buffers[1]));
|
||||
CHECK_NVJPEG(nvjpegBufferDeviceDestroy(params.device_buffer));
|
||||
CHECK_NVJPEG(nvjpegJpegStateDestroy(params.nvjpeg_decoupled_state));
|
||||
CHECK_NVJPEG(nvjpegDecoderDestroy(params.nvjpeg_decoder));
|
||||
}
|
||||
|
||||
int decode_images(const FileData& img_data, const std::vector<size_t>& img_len,
|
||||
std::vector<nvjpegImage_t>& out, decode_params_t& params,
|
||||
double& time) {
|
||||
CHECK_CUDA(cudaStreamSynchronize(params.stream));
|
||||
|
||||
std::vector<const unsigned char*> batched_bitstreams;
|
||||
std::vector<size_t> batched_bitstreams_size;
|
||||
std::vector<nvjpegImage_t> batched_output;
|
||||
|
||||
// bit-streams that batched decode cannot handle
|
||||
std::vector<const unsigned char*> otherdecode_bitstreams;
|
||||
std::vector<size_t> otherdecode_bitstreams_size;
|
||||
std::vector<nvjpegImage_t> otherdecode_output;
|
||||
|
||||
if (params.hw_decode_available) {
|
||||
for (int i = 0; i < params.batch_size; i++) {
|
||||
// extract bitstream meta data to figure out whether a bit-stream can be
|
||||
// decoded
|
||||
nvjpegJpegStreamParseHeader(params.nvjpeg_handle,
|
||||
(const unsigned char*)img_data[i].data(),
|
||||
img_len[i], params.jpeg_streams[0]);
|
||||
int isSupported = -1;
|
||||
nvjpegDecodeBatchedSupported(params.nvjpeg_handle, params.jpeg_streams[0],
|
||||
&isSupported);
|
||||
|
||||
if (isSupported == 0) {
|
||||
batched_bitstreams.push_back((const unsigned char*)img_data[i].data());
|
||||
batched_bitstreams_size.push_back(img_len[i]);
|
||||
batched_output.push_back(out[i]);
|
||||
} else {
|
||||
otherdecode_bitstreams.push_back(
|
||||
(const unsigned char*)img_data[i].data());
|
||||
otherdecode_bitstreams_size.push_back(img_len[i]);
|
||||
otherdecode_output.push_back(out[i]);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < params.batch_size; i++) {
|
||||
otherdecode_bitstreams.push_back(
|
||||
(const unsigned char*)img_data[i].data());
|
||||
otherdecode_bitstreams_size.push_back(img_len[i]);
|
||||
otherdecode_output.push_back(out[i]);
|
||||
}
|
||||
}
|
||||
|
||||
if (batched_bitstreams.size() > 0) {
|
||||
CHECK_NVJPEG(nvjpegDecodeBatchedInitialize(
|
||||
params.nvjpeg_handle, params.nvjpeg_state, batched_bitstreams.size(), 1,
|
||||
params.fmt));
|
||||
|
||||
CHECK_NVJPEG(nvjpegDecodeBatched(
|
||||
params.nvjpeg_handle, params.nvjpeg_state, batched_bitstreams.data(),
|
||||
batched_bitstreams_size.data(), batched_output.data(), params.stream));
|
||||
}
|
||||
|
||||
if (otherdecode_bitstreams.size() > 0) {
|
||||
CHECK_NVJPEG(nvjpegStateAttachDeviceBuffer(params.nvjpeg_decoupled_state,
|
||||
params.device_buffer));
|
||||
int buffer_index = 0;
|
||||
CHECK_NVJPEG(nvjpegDecodeParamsSetOutputFormat(params.nvjpeg_decode_params,
|
||||
params.fmt));
|
||||
for (int i = 0; i < params.batch_size; i++) {
|
||||
CHECK_NVJPEG(nvjpegJpegStreamParse(params.nvjpeg_handle,
|
||||
otherdecode_bitstreams[i],
|
||||
otherdecode_bitstreams_size[i], 0, 0,
|
||||
params.jpeg_streams[buffer_index]));
|
||||
|
||||
CHECK_NVJPEG(nvjpegStateAttachPinnedBuffer(
|
||||
params.nvjpeg_decoupled_state, params.pinned_buffers[buffer_index]));
|
||||
|
||||
CHECK_NVJPEG(nvjpegDecodeJpegHost(
|
||||
params.nvjpeg_handle, params.nvjpeg_decoder,
|
||||
params.nvjpeg_decoupled_state, params.nvjpeg_decode_params,
|
||||
params.jpeg_streams[buffer_index]));
|
||||
|
||||
CHECK_CUDA(cudaStreamSynchronize(params.stream));
|
||||
|
||||
CHECK_NVJPEG(nvjpegDecodeJpegTransferToDevice(
|
||||
params.nvjpeg_handle, params.nvjpeg_decoder,
|
||||
params.nvjpeg_decoupled_state, params.jpeg_streams[buffer_index],
|
||||
params.stream));
|
||||
|
||||
buffer_index = 1 - buffer_index; // switch pinned buffer in pipeline mode
|
||||
// to avoid an extra sync
|
||||
|
||||
CHECK_NVJPEG(
|
||||
nvjpegDecodeJpegDevice(params.nvjpeg_handle, params.nvjpeg_decoder,
|
||||
params.nvjpeg_decoupled_state,
|
||||
&otherdecode_output[i], params.stream));
|
||||
}
|
||||
}
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
double process_images(const FileNames& image_names, decode_params_t& params,
|
||||
double& total, std::vector<nvjpegImage_t>& iout,
|
||||
std::vector<FDTensor*>& output_buffers,
|
||||
std::vector<int>& widths, std::vector<int>& heights) {
|
||||
FDASSERT(image_names.size() == params.batch_size,
|
||||
"Number of images and batch size must be equal.");
|
||||
// vector for storing raw files and file lengths
|
||||
FileData file_data(params.batch_size);
|
||||
std::vector<size_t> file_len(params.batch_size);
|
||||
FileNames current_names(params.batch_size);
|
||||
// we wrap over image files to process total_images of files
|
||||
auto file_iter = image_names.begin();
|
||||
|
||||
// output buffer sizes, for convenience
|
||||
std::vector<nvjpegImage_t> isz(params.batch_size);
|
||||
|
||||
for (long unsigned int i = 0; i < iout.size(); i++) {
|
||||
for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) {
|
||||
iout[i].channel[c] = NULL;
|
||||
iout[i].pitch[c] = 0;
|
||||
isz[i].pitch[c] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
if (read_images(image_names, file_data, file_len)) {
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
if (prepare_buffers(file_data, file_len, widths, heights, iout, isz,
|
||||
output_buffers, image_names, params)) {
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
double time;
|
||||
if (decode_images(file_data, file_len, iout, params, time)) {
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
|
||||
return EXIT_SUCCESS;
|
||||
}
|
||||
|
||||
void init_decoder(decode_params_t& params) {
|
||||
params.hw_decode_available = true;
|
||||
nvjpegDevAllocator_t dev_allocator = {&dev_malloc, &dev_free};
|
||||
nvjpegPinnedAllocator_t pinned_allocator = {&host_malloc, &host_free};
|
||||
nvjpegStatus_t status =
|
||||
nvjpegCreateEx(NVJPEG_BACKEND_HARDWARE, &dev_allocator, &pinned_allocator,
|
||||
NVJPEG_FLAGS_DEFAULT, ¶ms.nvjpeg_handle);
|
||||
if (status == NVJPEG_STATUS_ARCH_MISMATCH) {
|
||||
std::cout << "Hardware Decoder not supported. "
|
||||
"Falling back to default backend"
|
||||
<< std::endl;
|
||||
CHECK_NVJPEG(nvjpegCreateEx(NVJPEG_BACKEND_DEFAULT, &dev_allocator,
|
||||
&pinned_allocator, NVJPEG_FLAGS_DEFAULT,
|
||||
¶ms.nvjpeg_handle));
|
||||
params.hw_decode_available = false;
|
||||
} else {
|
||||
CHECK_NVJPEG(status);
|
||||
}
|
||||
|
||||
CHECK_NVJPEG(
|
||||
nvjpegJpegStateCreate(params.nvjpeg_handle, ¶ms.nvjpeg_state));
|
||||
|
||||
create_decoupled_api_handles(params);
|
||||
}
|
||||
|
||||
void destroy_decoder(decode_params_t& params) {
|
||||
destroy_decoupled_api_handles(params);
|
||||
CHECK_NVJPEG(nvjpegJpegStateDestroy(params.nvjpeg_state));
|
||||
CHECK_NVJPEG(nvjpegDestroy(params.nvjpeg_handle));
|
||||
}
|
||||
|
||||
} // namespace nvjpeg
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
||||
|
||||
#endif // ENABLE_NVJPEG
|
69
fastdeploy/vision/common/image_decoder/nvjpeg_decoder.h
Normal file
69
fastdeploy/vision/common/image_decoder/nvjpeg_decoder.h
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
// Part of the following code in this file refs to
|
||||
// https://github.com/CVCUDA/CV-CUDA/blob/release_v0.2.x/samples/common/NvDecoder.h
|
||||
//
|
||||
// Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
// Licensed under the Apache-2.0 license
|
||||
// \brief
|
||||
// \author NVIDIA
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifdef ENABLE_NVJPEG
|
||||
#include "fastdeploy/core/fd_tensor.h"
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <nvjpeg.h>
|
||||
|
||||
|
||||
namespace fastdeploy {
|
||||
namespace vision {
|
||||
namespace nvjpeg {
|
||||
|
||||
typedef std::vector<std::string> FileNames;
|
||||
typedef std::vector<std::vector<char>> FileData;
|
||||
|
||||
struct decode_params_t {
|
||||
int batch_size;
|
||||
nvjpegJpegState_t nvjpeg_state;
|
||||
nvjpegHandle_t nvjpeg_handle;
|
||||
cudaStream_t stream;
|
||||
|
||||
// used with decoupled API
|
||||
nvjpegJpegState_t nvjpeg_decoupled_state;
|
||||
nvjpegBufferPinned_t pinned_buffers[2]; // 2 buffers for pipelining
|
||||
nvjpegBufferDevice_t device_buffer;
|
||||
nvjpegJpegStream_t jpeg_streams[2]; // 2 streams for pipelining
|
||||
nvjpegDecodeParams_t nvjpeg_decode_params;
|
||||
nvjpegJpegDecoder_t nvjpeg_decoder;
|
||||
|
||||
nvjpegOutputFormat_t fmt;
|
||||
bool hw_decode_available;
|
||||
};
|
||||
|
||||
void init_decoder(decode_params_t& params);
|
||||
void destroy_decoder(decode_params_t& params);
|
||||
|
||||
double process_images(const FileNames& image_names, decode_params_t& params,
|
||||
double& total, std::vector<nvjpegImage_t>& iout,
|
||||
std::vector<FDTensor*>& output_buffers,
|
||||
std::vector<int>& widths, std::vector<int>& heights);
|
||||
|
||||
} // namespace nvjpeg
|
||||
} // namespace vision
|
||||
} // namespace fastdeploy
|
||||
|
||||
#endif // ENABLE_NVJPEG
|
@@ -77,6 +77,16 @@ bool ProcessorManager::Run(std::vector<FDMat>* images,
|
||||
}
|
||||
(*images)[i].input_cache = &input_caches_[i];
|
||||
(*images)[i].output_cache = &output_caches_[i];
|
||||
if ((*images)[i].mat_type == ProcLib::CUDA) {
|
||||
// Make a copy of the input data ptr, so that the original data ptr of
|
||||
// FDMat won't be modified.
|
||||
auto fd_tensor = std::make_shared<FDTensor>();
|
||||
fd_tensor->SetExternalData(
|
||||
(*images)[i].Tensor()->shape, (*images)[i].Tensor()->Dtype(),
|
||||
(*images)[i].Tensor()->Data(), (*images)[i].Tensor()->device,
|
||||
(*images)[i].Tensor()->device_id);
|
||||
(*images)[i].SetTensor(fd_tensor);
|
||||
}
|
||||
}
|
||||
|
||||
bool ret = Apply(&image_batch, outputs);
|
||||
|
@@ -35,6 +35,10 @@ class FASTDEPLOY_DECL ProcessorManager {
|
||||
|
||||
bool CudaUsed();
|
||||
|
||||
#ifdef WITH_GPU
|
||||
cudaStream_t Stream() const { return stream_; }
|
||||
#endif
|
||||
|
||||
void SetStream(FDMat* mat) {
|
||||
#ifdef WITH_GPU
|
||||
mat->SetStream(stream_);
|
||||
@@ -56,7 +60,7 @@ class FASTDEPLOY_DECL ProcessorManager {
|
||||
|
||||
int DeviceId() { return device_id_; }
|
||||
|
||||
/** \brief Process the input image and prepare input tensors for runtime
|
||||
/** \brief Process the input images and prepare input tensors for runtime
|
||||
*
|
||||
* \param[in] images The input image data list, all the elements are returned by cv::imread()
|
||||
* \param[in] outputs The output tensors which will feed in runtime
|
||||
|
@@ -37,7 +37,7 @@ cv::Mat* Mat::GetOpenCVMat() {
|
||||
#ifdef WITH_GPU
|
||||
FDASSERT(cudaStreamSynchronize(stream) == cudaSuccess,
|
||||
"[ERROR] Error occurs while sync cuda stream.");
|
||||
cpu_mat = CreateZeroCopyOpenCVMatFromTensor(fd_tensor);
|
||||
cpu_mat = CreateZeroCopyOpenCVMatFromTensor(*fd_tensor);
|
||||
mat_type = ProcLib::OPENCV;
|
||||
device = Device::CPU;
|
||||
return &cpu_mat;
|
||||
@@ -59,29 +59,53 @@ void* Mat::Data() {
|
||||
"fcv::Mat.");
|
||||
#endif
|
||||
} else if (device == Device::GPU) {
|
||||
return fd_tensor.Data();
|
||||
return fd_tensor->Data();
|
||||
}
|
||||
return cpu_mat.ptr();
|
||||
}
|
||||
|
||||
FDTensor* Mat::Tensor() {
|
||||
if (mat_type == ProcLib::OPENCV) {
|
||||
ShareWithTensor(&fd_tensor);
|
||||
ShareWithTensor(fd_tensor.get());
|
||||
} else if (mat_type == ProcLib::FLYCV) {
|
||||
#ifdef ENABLE_FLYCV
|
||||
cpu_mat = ConvertFlyCVMatToOpenCV(fcv_mat);
|
||||
mat_type = ProcLib::OPENCV;
|
||||
ShareWithTensor(&fd_tensor);
|
||||
ShareWithTensor(fd_tensor.get());
|
||||
#else
|
||||
FDASSERT(false, "FastDeploy didn't compiled with FlyCV!");
|
||||
#endif
|
||||
}
|
||||
return &fd_tensor;
|
||||
return fd_tensor.get();
|
||||
}
|
||||
|
||||
void Mat::SetTensor(FDTensor* tensor) {
|
||||
fd_tensor.SetExternalData(tensor->Shape(), tensor->Dtype(), tensor->Data(),
|
||||
fd_tensor->SetExternalData(tensor->Shape(), tensor->Dtype(), tensor->Data(),
|
||||
tensor->device, tensor->device_id);
|
||||
device = tensor->device;
|
||||
if (layout == Layout::HWC) {
|
||||
height = tensor->Shape()[0];
|
||||
width = tensor->Shape()[1];
|
||||
channels = tensor->Shape()[2];
|
||||
} else if (layout == Layout::CHW) {
|
||||
channels = tensor->Shape()[0];
|
||||
height = tensor->Shape()[1];
|
||||
width = tensor->Shape()[2];
|
||||
}
|
||||
}
|
||||
|
||||
void Mat::SetTensor(std::shared_ptr<FDTensor>& tensor) {
|
||||
fd_tensor = tensor;
|
||||
device = tensor->device;
|
||||
if (layout == Layout::HWC) {
|
||||
height = tensor->Shape()[0];
|
||||
width = tensor->Shape()[1];
|
||||
channels = tensor->Shape()[2];
|
||||
} else if (layout == Layout::CHW) {
|
||||
channels = tensor->Shape()[0];
|
||||
height = tensor->Shape()[1];
|
||||
width = tensor->Shape()[2];
|
||||
}
|
||||
}
|
||||
|
||||
void Mat::ShareWithTensor(FDTensor* tensor) {
|
||||
@@ -134,7 +158,7 @@ void Mat::PrintInfo(const std::string& flag) {
|
||||
#ifdef WITH_GPU
|
||||
FDASSERT(cudaStreamSynchronize(stream) == cudaSuccess,
|
||||
"[ERROR] Error occurs while sync cuda stream.");
|
||||
cv::Mat tmp_mat = CreateZeroCopyOpenCVMatFromTensor(fd_tensor);
|
||||
cv::Mat tmp_mat = CreateZeroCopyOpenCVMatFromTensor(*fd_tensor);
|
||||
cv::Scalar mean = cv::mean(tmp_mat);
|
||||
for (int i = 0; i < Channels(); ++i) {
|
||||
std::cout << mean[i] << " ";
|
||||
@@ -157,7 +181,7 @@ FDDataType Mat::Type() {
|
||||
"fcv::Mat.");
|
||||
#endif
|
||||
} else if (mat_type == ProcLib::CUDA || mat_type == ProcLib::CVCUDA) {
|
||||
return fd_tensor.Dtype();
|
||||
return fd_tensor->Dtype();
|
||||
}
|
||||
return OpenCVDataTypeToFD(cpu_mat.type());
|
||||
}
|
||||
@@ -262,6 +286,10 @@ FDTensor* CreateCachedGpuInputTensor(Mat* mat) {
|
||||
#ifdef WITH_GPU
|
||||
FDTensor* src = mat->Tensor();
|
||||
if (src->device == Device::GPU) {
|
||||
if (src->Data() == mat->output_cache->Data()) {
|
||||
std::swap(mat->input_cache, mat->output_cache);
|
||||
std::swap(mat->input_cache->name, mat->output_cache->name);
|
||||
}
|
||||
return src;
|
||||
} else if (src->device == Device::CPU) {
|
||||
// Mats on CPU, we need copy these tensors from CPU to GPU
|
||||
|
@@ -49,7 +49,6 @@ struct FASTDEPLOY_DECL Mat {
|
||||
#endif
|
||||
|
||||
Mat(const Mat& mat) = default;
|
||||
// Move assignment
|
||||
Mat& operator=(const Mat& mat) = default;
|
||||
|
||||
// Move constructor
|
||||
@@ -96,6 +95,8 @@ struct FASTDEPLOY_DECL Mat {
|
||||
// Set fd_tensor
|
||||
void SetTensor(FDTensor* tensor);
|
||||
|
||||
void SetTensor(std::shared_ptr<FDTensor>& tensor);
|
||||
|
||||
private:
|
||||
int channels;
|
||||
int height;
|
||||
@@ -109,7 +110,7 @@ struct FASTDEPLOY_DECL Mat {
|
||||
#endif
|
||||
// Currently, fd_tensor is only used by CUDA and CV-CUDA,
|
||||
// OpenCV and FlyCV are not using it.
|
||||
FDTensor fd_tensor;
|
||||
std::shared_ptr<FDTensor> fd_tensor = std::make_shared<FDTensor>();
|
||||
|
||||
public:
|
||||
FDDataType Type();
|
||||
|
@@ -27,7 +27,7 @@ void FDMatBatch::SetStream(cudaStream_t s) {
|
||||
|
||||
FDTensor* FDMatBatch::Tensor() {
|
||||
if (has_batched_tensor) {
|
||||
return &fd_tensor;
|
||||
return fd_tensor.get();
|
||||
}
|
||||
FDASSERT(CheckShapeConsistency(mats), "Mats shapes are not consistent.")
|
||||
// Each mat has its own tensor,
|
||||
@@ -45,11 +45,11 @@ FDTensor* FDMatBatch::Tensor() {
|
||||
num_bytes, device, false);
|
||||
}
|
||||
SetTensor(input_cache);
|
||||
return &fd_tensor;
|
||||
return fd_tensor.get();
|
||||
}
|
||||
|
||||
void FDMatBatch::SetTensor(FDTensor* tensor) {
|
||||
fd_tensor.SetExternalData(tensor->Shape(), tensor->Dtype(), tensor->Data(),
|
||||
fd_tensor->SetExternalData(tensor->Shape(), tensor->Dtype(), tensor->Data(),
|
||||
tensor->device, tensor->device_id);
|
||||
has_batched_tensor = true;
|
||||
}
|
||||
|
@@ -29,7 +29,7 @@ struct FASTDEPLOY_DECL FDMatBatch {
|
||||
// MatBatch is intialized with a list of mats,
|
||||
// the data is stored in the mats separately.
|
||||
// Call Tensor() function to get a batched 4-dimension tensor.
|
||||
explicit FDMatBatch(std::vector<Mat>* _mats) {
|
||||
explicit FDMatBatch(std::vector<FDMat>* _mats) {
|
||||
mats = _mats;
|
||||
layout = FDMatBatchLayout::NHWC;
|
||||
mat_type = ProcLib::OPENCV;
|
||||
@@ -44,7 +44,7 @@ struct FASTDEPLOY_DECL FDMatBatch {
|
||||
#ifdef WITH_GPU
|
||||
cudaStream_t stream = nullptr;
|
||||
#endif
|
||||
FDTensor fd_tensor;
|
||||
std::shared_ptr<FDTensor> fd_tensor = std::make_shared<FDTensor>();
|
||||
|
||||
public:
|
||||
// When using CV-CUDA/CUDA, please set input/output cache,
|
||||
|
@@ -81,7 +81,7 @@ bool NormalizeAndPermute::ImplByCuda(FDMatBatch* mat_batch) {
|
||||
|
||||
// Prepare output tensor
|
||||
mat_batch->output_cache->Resize(src->Shape(), FDDataType::FP32,
|
||||
"output_cache", Device::GPU);
|
||||
"batch_output_cache", Device::GPU);
|
||||
// NHWC -> NCHW
|
||||
std::swap(mat_batch->output_cache->shape[1],
|
||||
mat_batch->output_cache->shape[3]);
|
||||
|
@@ -1,13 +1,9 @@
|
||||
English | [中文](README_CN.md)
|
||||
|
||||
|
||||
# Tutorials
|
||||
|
||||
|
||||
|
||||
This directory provides some tutorials for FastDeploy. For other model deployment, please refer to the example [FastDeploy/examples](../examples) directly.
|
||||
|
||||
|
||||
- Intel independent graphics card/integrated graphics card deployment [see intel_gpu](intel_gpu)
|
||||
|
||||
- Model multithreaded call [see multi_thread](multi_thread)
|
||||
- Image decoding, including hardward decoding, e.g. nvJPEG [image_decoder](image_decoder)
|
||||
|
@@ -7,3 +7,4 @@
|
||||
|
||||
- Intel独立显卡/集成显卡部署 [见intel_gpu](intel_gpu)
|
||||
- 模型多线程调用 [见multi_thread](multi_thread)
|
||||
- 图片解码(含nvJPEG硬解码) [见image_decoder](image_decoder)
|
||||
|
16
tutorials/image_decoder/README.md
Normal file
16
tutorials/image_decoder/README.md
Normal file
@@ -0,0 +1,16 @@
|
||||
English | [中文](README_CN.md)
|
||||
|
||||
# Image Decoder
|
||||
|
||||
Currently, we support below image decoder libs:
|
||||
- OpenCV
|
||||
- nvJPEG (Needs NVIDIA GPU, doesn't support Jetson)
|
||||
|
||||
## Example
|
||||
|
||||
- [C++ Example](cpp)
|
||||
- Python API(WIP)
|
||||
|
||||
## nvJPEG vs. OpenCV performance benchmark
|
||||
|
||||
Refer to: https://github.com/PaddlePaddle/FastDeploy/pull/1288#issuecomment-1427749772
|
16
tutorials/image_decoder/README_CN.md
Normal file
16
tutorials/image_decoder/README_CN.md
Normal file
@@ -0,0 +1,16 @@
|
||||
简体中文 | [English](README.md)
|
||||
|
||||
# Image Decoder
|
||||
|
||||
图片解码库,目前支持以下图片解码库:
|
||||
- OpenCV
|
||||
- nvJPEG (依赖NVIDIA GPU,不支持Jetson)
|
||||
|
||||
## 示例代码
|
||||
|
||||
- [C++示例](cpp)
|
||||
- Python API仍在开发中...
|
||||
|
||||
## nvJPEG和OpenCV性能对比数据
|
||||
|
||||
参见:https://github.com/PaddlePaddle/FastDeploy/pull/1288#issuecomment-1427749772
|
11
tutorials/image_decoder/cpp/CMakeLists.txt
Normal file
11
tutorials/image_decoder/cpp/CMakeLists.txt
Normal file
@@ -0,0 +1,11 @@
|
||||
PROJECT(image_decoder C CXX)
|
||||
CMAKE_MINIMUM_REQUIRED (VERSION 3.10)
|
||||
|
||||
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
|
||||
|
||||
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
|
||||
|
||||
include_directories(${FASTDEPLOY_INCS})
|
||||
|
||||
add_executable(image_decoder ${PROJECT_SOURCE_DIR}/main.cc)
|
||||
target_link_libraries(image_decoder ${FASTDEPLOY_LIBS})
|
22
tutorials/image_decoder/cpp/README.md
Normal file
22
tutorials/image_decoder/cpp/README.md
Normal file
@@ -0,0 +1,22 @@
|
||||
English | [中文](README_CN.md)
|
||||
|
||||
# Image Decoder C++ Example
|
||||
|
||||
1. [Build FastDeploy](../docs/cn/build_and_install) or download [FastDeploy prebuilt library](../docs/cn/build_and_install/download_prebuilt_libraries.md)
|
||||
|
||||
2. Build example
|
||||
```bash
|
||||
mkdir build
|
||||
cd build
|
||||
|
||||
# [PATH-TO-FASTDEPLOY] is the install directory of FastDeploy
|
||||
cmake .. -DFASTDEPLOY_INSTALL_DIR=[PATH-TO-FASTDEPLOY]
|
||||
make -j
|
||||
|
||||
# Download the test image
|
||||
wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
|
||||
|
||||
# OpenCV decoder
|
||||
./image_decoder ILSVRC2012_val_00000010.jpeg 0
|
||||
# nvJPEG
|
||||
./image_decoder ILSVRC2012_val_00000010.jpeg 1
|
22
tutorials/image_decoder/cpp/README_CN.md
Normal file
22
tutorials/image_decoder/cpp/README_CN.md
Normal file
@@ -0,0 +1,22 @@
|
||||
简体中文 | [English](README.md)
|
||||
|
||||
# Image Decoder C++示例
|
||||
|
||||
1. [编译FastDeploy](../docs/cn/build_and_install), 或直接下载[FastDeploy预编译库](../docs/cn/build_and_install/download_prebuilt_libraries.md)
|
||||
|
||||
2. 编译示例
|
||||
```bash
|
||||
mkdir build
|
||||
cd build
|
||||
|
||||
# [PATH-TO-FASTDEPLOY]需替换为FastDeploy的安装路径
|
||||
cmake .. -DFASTDEPLOY_INSTALL_DIR=[PATH-TO-FASTDEPLOY]
|
||||
make -j
|
||||
|
||||
# 下载测试图片
|
||||
wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
|
||||
|
||||
# OpenCV解码
|
||||
./image_decoder ILSVRC2012_val_00000010.jpeg 0
|
||||
# nvJPEG
|
||||
./image_decoder ILSVRC2012_val_00000010.jpeg 1
|
57
tutorials/image_decoder/cpp/main.cc
Normal file
57
tutorials/image_decoder/cpp/main.cc
Normal file
@@ -0,0 +1,57 @@
|
||||
#include "fastdeploy/vision/common/image_decoder/image_decoder.h"
|
||||
|
||||
namespace fdvis = fastdeploy::vision;
|
||||
namespace fd = fastdeploy;
|
||||
|
||||
void OpenCVImageDecode(const std::string& img_name) {
|
||||
fdvis::FDMat mat;
|
||||
auto img_decoder = new fdvis::ImageDecoder();
|
||||
img_decoder->Decode(img_name, &mat);
|
||||
mat.PrintInfo("");
|
||||
delete img_decoder;
|
||||
}
|
||||
|
||||
void NvJpegImageDecode(const std::string& img_name) {
|
||||
std::vector<fdvis::FDMat> mats(1);
|
||||
std::vector<fastdeploy::FDTensor> caches(1);
|
||||
|
||||
cudaStream_t stream;
|
||||
cudaStreamCreate(&stream);
|
||||
// For nvJPEG decoder, we need set stream and output cache for the FDMat
|
||||
for (size_t i = 0; i < mats.size(); i++) {
|
||||
mats[i].output_cache = &caches[i];
|
||||
mats[i].SetStream(stream);
|
||||
}
|
||||
auto img_decoder = new fdvis::ImageDecoder(fdvis::ImageDecoderLib::NVJPEG);
|
||||
|
||||
// This is batch decode API, for single image decode API,
|
||||
// please refer to OpenCVImageDecode()
|
||||
img_decoder->BatchDecode({img_name}, &mats);
|
||||
|
||||
for (size_t i = 0; i < mats.size(); i++) {
|
||||
std::cout << "Mat type: " << mats[i].mat_type << ", "
|
||||
<< "DataType=" << mats[i].Type() << ", "
|
||||
<< "Channel=" << mats[i].Channels() << ", "
|
||||
<< "Height=" << mats[i].Height() << ", "
|
||||
<< "Width=" << mats[i].Width() << std::endl;
|
||||
}
|
||||
|
||||
cudaStreamDestroy(stream);
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
if (argc < 3) {
|
||||
std::cout << "Usage: image_decoder path/to/image run_option, "
|
||||
"e.g ./image_decoder ./test.jpeg 0"
|
||||
<< std::endl;
|
||||
std::cout << "Run_option 0: OpenCV; 1: nvJPEG " << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (std::atoi(argv[2]) == 0) {
|
||||
OpenCVImageDecode(argv[1]);
|
||||
} else if (std::atoi(argv[2]) == 1) {
|
||||
NvJpegImageDecode(argv[1]);
|
||||
}
|
||||
return 0;
|
||||
}
|
Reference in New Issue
Block a user