[nvJPEG] Integrate nvJPEG decoder (#1288)

* nvjpeg cmake

* add common decoder, nvjpeg decoder and add image name predict api

* ppclas support nvjpeg decoder

* remove useless comments

* image decoder support opencv

* nvjpeg decode fallback to opencv

* fdtensor add nbytes_allocated

* single image decode api

* fix bug

* add pybind

* ignore nvjpeg on jetson

* fix cmake in

* predict on fdmat

* remove image names predict api, add image decoder tutorial

* Update __init__.py

* fix pybind
Wang Xinyu
2023-02-17 10:27:05 +08:00
committed by GitHub
parent e3a7ab4c14
commit efa46563f3
25 changed files with 875 additions and 44 deletions

CMakeLists.txt

@@ -300,10 +300,16 @@ if(WITH_GPU)
   include_directories(${CUDA_DIRECTORY}/include)
   if(WIN32)
     find_library(CUDA_LIB cudart ${CUDA_DIRECTORY}/lib/x64)
+    find_library(NVJPEG_LIB nvjpeg ${CUDA_DIRECTORY}/lib/x64)
+    add_definitions(-DENABLE_NVJPEG)
   else()
     find_library(CUDA_LIB cudart ${CUDA_DIRECTORY}/lib64)
+    if(NOT BUILD_ON_JETSON)
+      find_library(NVJPEG_LIB nvjpeg ${CUDA_DIRECTORY}/lib64)
+      add_definitions(-DENABLE_NVJPEG)
+    endif()
   endif()
-  list(APPEND DEPEND_LIBS ${CUDA_LIB})
+  list(APPEND DEPEND_LIBS ${CUDA_LIB} ${NVJPEG_LIB})
   # build CUDA source files in fastdeploy, CUDA source files include CUDA preprocessing, TRT plugins, etc.
   enable_language(CUDA)

FastDeploy.cmake.in

@@ -169,21 +169,25 @@ if(ENABLE_POROS_BACKEND)
 endif()
 
 if(WITH_GPU)
-  if (NOT CUDA_DIRECTORY)
+  if(NOT CUDA_DIRECTORY)
     set(CUDA_DIRECTORY "/usr/local/cuda")
   endif()
   if(WIN32)
     find_library(CUDA_LIB cudart ${CUDA_DIRECTORY}/lib/x64)
+    find_library(NVJPEG_LIB nvjpeg ${CUDA_DIRECTORY}/lib/x64)
   else()
     find_library(CUDA_LIB cudart ${CUDA_DIRECTORY}/lib64)
+    if(NOT BUILD_ON_JETSON)
+      find_library(NVJPEG_LIB nvjpeg ${CUDA_DIRECTORY}/lib64)
+    endif()
   endif()
   if(NOT CUDA_LIB)
     message(FATAL_ERROR "[FastDeploy] Cannot find library cudart in ${CUDA_DIRECTORY}, Please define CUDA_DIRECTORY, e.g -DCUDA_DIRECTORY=/path/to/cuda")
   endif()
-  list(APPEND FASTDEPLOY_LIBS ${CUDA_LIB})
+  list(APPEND FASTDEPLOY_LIBS ${CUDA_LIB} ${NVJPEG_LIB})
   list(APPEND FASTDEPLOY_INCS ${CUDA_DIRECTORY}/include)
-  if (ENABLE_TRT_BACKEND)
+  if(ENABLE_TRT_BACKEND)
     if(BUILD_ON_JETSON)
       find_library(TRT_INFER_LIB nvinfer /usr/lib/aarch64-linux-gnu/)
       find_library(TRT_ONNX_LIB nvonnxparser /usr/lib/aarch64-linux-gnu/)

fastdeploy/core/fd_tensor.cc

@@ -245,12 +245,13 @@ void FDTensor::PrintInfo(const std::string& prefix) const {
 bool FDTensor::ReallocFn(size_t nbytes) {
   if (device == Device::GPU) {
 #ifdef WITH_GPU
-    size_t original_nbytes = Nbytes();
+    size_t original_nbytes = nbytes_allocated;
     if (nbytes > original_nbytes) {
       if (buffer_ != nullptr) {
         FDDeviceFree()(buffer_);
       }
       FDDeviceAllocator()(&buffer_, nbytes);
+      nbytes_allocated = nbytes;
     }
     return buffer_ != nullptr;
 #else
@@ -262,12 +263,13 @@ bool FDTensor::ReallocFn(size_t nbytes) {
   } else {
     if (is_pinned_memory) {
 #ifdef WITH_GPU
-      size_t original_nbytes = Nbytes();
+      size_t original_nbytes = nbytes_allocated;
       if (nbytes > original_nbytes) {
         if (buffer_ != nullptr) {
           FDDeviceHostFree()(buffer_);
         }
         FDDeviceHostAllocator()(&buffer_, nbytes);
+        nbytes_allocated = nbytes;
       }
       return buffer_ != nullptr;
 #else
@@ -278,6 +280,7 @@ bool FDTensor::ReallocFn(size_t nbytes) {
 #endif
     }
     buffer_ = realloc(buffer_, nbytes);
+    nbytes_allocated = nbytes;
     return buffer_ != nullptr;
   }
 }
@@ -299,6 +302,7 @@ void FDTensor::FreeFn() {
       }
     }
     buffer_ = nullptr;
+    nbytes_allocated = 0;
   }
 }
@@ -380,7 +384,7 @@ FDTensor::FDTensor(const FDTensor& other)
       device_id(other.device_id) {
   // Copy buffer
   if (other.buffer_ == nullptr) {
-    buffer_ = nullptr;
+    FreeFn();
   } else {
     size_t nbytes = Nbytes();
     FDASSERT(ReallocFn(nbytes),
@@ -396,7 +400,8 @@ FDTensor::FDTensor(FDTensor&& other)
       dtype(other.dtype),
       external_data_ptr(other.external_data_ptr),
       device(other.device),
-      device_id(other.device_id) {
+      device_id(other.device_id),
+      nbytes_allocated(other.nbytes_allocated) {
   other.name = "";
   // Note(zhoushunjie): Avoid double free.
   other.buffer_ = nullptr;
@@ -435,6 +440,7 @@ FDTensor& FDTensor::operator=(FDTensor&& other) {
     dtype = other.dtype;
     device = other.device;
     device_id = other.device_id;
+    nbytes_allocated = other.nbytes_allocated;
     other.name = "";
     // Note(zhoushunjie): Avoid double free.

fastdeploy/core/fd_tensor.h

@@ -54,6 +54,11 @@ struct FASTDEPLOY_DECL FDTensor {
   // other devices' data
   std::vector<int8_t> temporary_cpu_buffer;
 
+  // The number of bytes allocated so far.
+  // When resizing GPU memory, we will free and realloc the memory only if the
+  // required size is larger than this value.
+  size_t nbytes_allocated = 0;
+
   // Get data buffer pointer
   void* MutableData();
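The effect of `nbytes_allocated` is to make `ReallocFn` grow-only: the buffer is freed and reallocated only when a request exceeds the bytes already held, so repeated inference on same-or-smaller inputs reuses the existing GPU allocation. A minimal host-side sketch of the pattern (illustrative only; `malloc`/`free` stand in for `FDDeviceAllocator`/`FDDeviceFree`):

#include <cstddef>
#include <cstdlib>

// Grow-only buffer: shrinking or equal-size requests keep the allocation.
struct GrowOnlyBuffer {
  void* buffer_ = nullptr;
  std::size_t nbytes_allocated = 0;

  bool Realloc(std::size_t nbytes) {
    if (nbytes > nbytes_allocated) {  // grow only, never shrink
      std::free(buffer_);             // stand-in for FDDeviceFree
      buffer_ = std::malloc(nbytes);  // stand-in for FDDeviceAllocator
      nbytes_allocated = nbytes;
    }
    return buffer_ != nullptr;
  }

  ~GrowOnlyBuffer() { std::free(buffer_); }
};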

fastdeploy/vision/classification/ppcls/model.cc (Executable file → Normal file)

@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "fastdeploy/vision/classification/ppcls/model.h"
+#include "fastdeploy/utils/unique_ptr.h"
 
 namespace fastdeploy {
@@ -23,7 +24,8 @@ PaddleClasModel::PaddleClasModel(const std::string& model_file,
                                  const std::string& params_file,
                                  const std::string& config_file,
                                  const RuntimeOption& custom_option,
-                                 const ModelFormat& model_format) : preprocessor_(config_file) {
+                                 const ModelFormat& model_format)
+    : preprocessor_(config_file) {
   if (model_format == ModelFormat::PADDLE) {
     valid_cpu_backends = {Backend::OPENVINO, Backend::PDINFER, Backend::ORT,
                           Backend::LITE};
@@ -32,15 +34,14 @@ PaddleClasModel::PaddleClasModel(const std::string& model_file,
     valid_ascend_backends = {Backend::LITE};
     valid_kunlunxin_backends = {Backend::LITE};
     valid_ipu_backends = {Backend::PDINFER};
-  }else if (model_format == ModelFormat::SOPHGO) {
+  } else if (model_format == ModelFormat::SOPHGO) {
     valid_sophgonpu_backends = {Backend::SOPHGOTPU};
-  }
-  else {
+  } else {
     valid_cpu_backends = {Backend::ORT, Backend::OPENVINO};
     valid_gpu_backends = {Backend::ORT, Backend::TRT};
     valid_rknpu_backends = {Backend::RKNPU2};
   }
   runtime_option = custom_option;
   runtime_option.model_format = model_format;
   runtime_option.model_file = model_file;
@@ -48,8 +49,9 @@ PaddleClasModel::PaddleClasModel(const std::string& model_file,
   initialized = Initialize();
 }
 
 std::unique_ptr<PaddleClasModel> PaddleClasModel::Clone() const {
-  std::unique_ptr<PaddleClasModel> clone_model = utils::make_unique<PaddleClasModel>(PaddleClasModel(*this));
+  std::unique_ptr<PaddleClasModel> clone_model =
+      utils::make_unique<PaddleClasModel>(PaddleClasModel(*this));
   clone_model->SetRuntime(clone_model->CloneRuntime());
   return clone_model;
 }
@@ -71,17 +73,30 @@ bool PaddleClasModel::Predict(cv::Mat* im, ClassifyResult* result, int topk) {
 }
 
 bool PaddleClasModel::Predict(const cv::Mat& im, ClassifyResult* result) {
+  FDMat mat = WrapMat(im);
+  return Predict(mat, result);
+}
+
+bool PaddleClasModel::BatchPredict(const std::vector<cv::Mat>& images,
+                                   std::vector<ClassifyResult>* results) {
+  std::vector<FDMat> mats = WrapMat(images);
+  return BatchPredict(mats, results);
+}
+
+bool PaddleClasModel::Predict(const FDMat& mat, ClassifyResult* result) {
   std::vector<ClassifyResult> results;
-  if (!BatchPredict({im}, &results)) {
+  std::vector<FDMat> mats = {mat};
+  if (!BatchPredict(mats, &results)) {
     return false;
   }
   *result = std::move(results[0]);
   return true;
 }
 
-bool PaddleClasModel::BatchPredict(const std::vector<cv::Mat>& images, std::vector<ClassifyResult>* results) {
-  std::vector<FDMat> fd_images = WrapMat(images);
-  if (!preprocessor_.Run(&fd_images, &reused_input_tensors_)) {
+bool PaddleClasModel::BatchPredict(const std::vector<FDMat>& mats,
+                                   std::vector<ClassifyResult>* results) {
+  std::vector<FDMat> fd_mats = mats;
+  if (!preprocessor_.Run(&fd_mats, &reused_input_tensors_)) {
     FDERROR << "Failed to preprocess the input image." << std::endl;
     return false;
   }
@@ -92,7 +107,8 @@ bool PaddleClasModel::BatchPredict(const std::vector<FDMat>& mats,
   }
 
   if (!postprocessor_.Run(reused_output_tensors_, results)) {
-    FDERROR << "Failed to postprocess the inference results by runtime." << std::endl;
+    FDERROR << "Failed to postprocess the inference results by runtime."
+            << std::endl;
     return false;
   }

fastdeploy/vision/classification/ppcls/model.h

@@ -75,6 +75,23 @@ class FASTDEPLOY_DECL PaddleClasModel : public FastDeployModel {
   virtual bool BatchPredict(const std::vector<cv::Mat>& imgs,
                             std::vector<ClassifyResult>* results);
 
+  /** \brief Predict the classification result for an input image
+   *
+   * \param[in] mat The input mat
+   * \param[in] result The output classification result
+   * \return true if the prediction succeeded, otherwise false
+   */
+  virtual bool Predict(const FDMat& mat, ClassifyResult* result);
+
+  /** \brief Predict the classification results for a batch of input images
+   *
+   * \param[in] mats The input mat list
+   * \param[in] results The output classification result list
+   * \return true if the prediction succeeded, otherwise false
+   */
+  virtual bool BatchPredict(const std::vector<FDMat>& mats,
+                            std::vector<ClassifyResult>* results);
+
   /// Get preprocessor reference of PaddleClasModel
   virtual PaddleClasPreprocessor& GetPreprocessor() {
     return preprocessor_;
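Combined with the image decoder added in this PR, the FDMat overloads let a decoded image reach the model without an intermediate cv::Mat. A hedged end-to-end sketch (file names are placeholders and the exact umbrella header may differ):

#include <iostream>
#include "fastdeploy/vision.h"
#include "fastdeploy/vision/common/image_decoder/image_decoder.h"

int main() {
  namespace vis = fastdeploy::vision;
  vis::classification::PaddleClasModel model(
      "model.pdmodel", "model.pdiparams", "config.yaml");  // placeholder paths

  vis::FDMat mat;
  vis::ImageDecoder decoder;          // OpenCV backend by default
  decoder.Decode("test.jpeg", &mat);  // decode directly into the FDMat

  vis::ClassifyResult result;
  model.Predict(mat, &result);        // new FDMat overload from this PR
  std::cout << result.Str() << std::endl;
  return 0;
}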

fastdeploy/vision/common/image_decoder/image_decoder.cc

@@ -0,0 +1,112 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/common/image_decoder/image_decoder.h"
#include "opencv2/imgcodecs.hpp"
namespace fastdeploy {
namespace vision {
ImageDecoder::ImageDecoder(ImageDecoderLib lib) {
if (lib == ImageDecoderLib::NVJPEG) {
#ifdef ENABLE_NVJPEG
nvjpeg::init_decoder(nvjpeg_params_);
#endif
}
lib_ = lib;
}
ImageDecoder::~ImageDecoder() {
if (lib_ == ImageDecoderLib::NVJPEG) {
#ifdef ENABLE_NVJPEG
nvjpeg::destroy_decoder(nvjpeg_params_);
#endif
}
}
bool ImageDecoder::Decode(const std::string& img_name, FDMat* mat) {
std::vector<FDMat> mats(1);
mats[0] = std::move(*mat);
if (!BatchDecode({img_name}, &mats)) {
return false;
}
*mat = std::move(mats[0]);
return true;
}
bool ImageDecoder::BatchDecode(const std::vector<std::string>& img_names,
std::vector<FDMat>* mats) {
if (lib_ == ImageDecoderLib::OPENCV) {
return ImplByOpenCV(img_names, mats);
} else if (lib_ == ImageDecoderLib::NVJPEG) {
return ImplByNvJpeg(img_names, mats);
}
return true;
}
bool ImageDecoder::ImplByOpenCV(const std::vector<std::string>& img_names,
std::vector<FDMat>* mats) {
for (size_t i = 0; i < img_names.size(); ++i) {
cv::Mat im = cv::imread(img_names[i]);
(*mats)[i].SetMat(im);
(*mats)[i].layout = Layout::HWC;
(*mats)[i].SetWidth(im.cols);
(*mats)[i].SetHeight(im.rows);
(*mats)[i].SetChannels(im.channels());
}
return true;
}
bool ImageDecoder::ImplByNvJpeg(const std::vector<std::string>& img_names,
std::vector<FDMat>* mats) {
#ifdef ENABLE_NVJPEG
nvjpeg_params_.batch_size = img_names.size();
std::vector<nvjpegImage_t> output_imgs(nvjpeg_params_.batch_size);
std::vector<int> widths(nvjpeg_params_.batch_size);
std::vector<int> heights(nvjpeg_params_.batch_size);
// TODO(wangxinyu): support other output format
nvjpeg_params_.fmt = NVJPEG_OUTPUT_BGRI;
double total;
nvjpeg_params_.stream = (*mats)[0].Stream();
std::vector<FDTensor*> output_buffers;
for (size_t i = 0; i < mats->size(); ++i) {
FDASSERT((*mats)[i].output_cache != nullptr,
"The output_cache of FDMat was not set.");
output_buffers.push_back((*mats)[i].output_cache);
}
if (nvjpeg::process_images(img_names, nvjpeg_params_, total, output_imgs,
output_buffers, widths, heights)) {
// If nvJPEG decoding fails, fall back to OpenCV,
// e.g. the PNG format is not supported by nvJPEG
FDWARNING << "nvJPEG decode failed, falling back to OpenCV for this batch"
<< std::endl;
return ImplByOpenCV(img_names, mats);
}
for (size_t i = 0; i < mats->size(); ++i) {
(*mats)[i].mat_type = ProcLib::CUDA;
(*mats)[i].layout = Layout::HWC;
(*mats)[i].SetTensor(output_buffers[i]);
}
#else
FDASSERT(false, "FastDeploy was not compiled with NVJPEG.");
#endif
return true;
}
} // namespace vision
} // namespace fastdeploy

fastdeploy/vision/common/image_decoder/image_decoder.h

@@ -0,0 +1,49 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/utils/utils.h"
#include "fastdeploy/vision/common/processors/mat.h"
#include "fastdeploy/vision/common/image_decoder/nvjpeg_decoder.h"
namespace fastdeploy {
namespace vision {
enum class FASTDEPLOY_DECL ImageDecoderLib { OPENCV, NVJPEG };
class FASTDEPLOY_DECL ImageDecoder {
public:
explicit ImageDecoder(ImageDecoderLib lib = ImageDecoderLib::OPENCV);
~ImageDecoder();
bool Decode(const std::string& img_name, FDMat* mat);
bool BatchDecode(const std::vector<std::string>& img_names,
std::vector<FDMat>* mats);
private:
bool ImplByOpenCV(const std::vector<std::string>& img_names,
std::vector<FDMat>* mats);
bool ImplByNvJpeg(const std::vector<std::string>& img_names,
std::vector<FDMat>* mats);
ImageDecoderLib lib_ = ImageDecoderLib::OPENCV;
#ifdef ENABLE_NVJPEG
nvjpeg::decode_params_t nvjpeg_params_;
#endif
};
} // namespace vision
} // namespace fastdeploy

fastdeploy/vision/common/image_decoder/nvjpeg_decoder.cc

@@ -0,0 +1,363 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Part of the following code in this file refers to
// https://github.com/CVCUDA/CV-CUDA/blob/release_v0.2.x/samples/common/NvDecoder.cpp
//
// Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Licensed under the Apache-2.0 license
// \brief
// \author NVIDIA
#ifdef ENABLE_NVJPEG
#include "fastdeploy/vision/common/image_decoder/nvjpeg_decoder.h"
namespace fastdeploy {
namespace vision {
namespace nvjpeg {
#define CHECK_CUDA(call) \
{ \
cudaError_t _e = (call); \
if (_e != cudaSuccess) { \
std::cout << "CUDA Runtime failure: '#" << _e << "' at " << __FILE__ \
<< ":" << __LINE__ << std::endl; \
exit(1); \
} \
}
#define CHECK_NVJPEG(call) \
{ \
nvjpegStatus_t _e = (call); \
if (_e != NVJPEG_STATUS_SUCCESS) { \
std::cout << "NVJPEG failure: '#" << _e << "' at " << __FILE__ << ":" \
<< __LINE__ << std::endl; \
exit(1); \
} \
}
static int dev_malloc(void** p, size_t s) { return (int)cudaMalloc(p, s); }
static int dev_free(void* p) { return (int)cudaFree(p); }
static int host_malloc(void** p, size_t s, unsigned int f) {
return (int)cudaHostAlloc(p, s, f);
}
static int host_free(void* p) { return (int)cudaFreeHost(p); }
static int read_images(const FileNames& image_names, FileData& raw_data,
std::vector<size_t>& raw_len) {
for (size_t i = 0; i < image_names.size(); ++i) {
if (image_names.size() == 0) {
std::cerr << "No valid images left in the input list, exit" << std::endl;
return EXIT_FAILURE;
}
// Read an image from disk.
std::ifstream input(image_names[i].c_str(),
std::ios::in | std::ios::binary | std::ios::ate);
if (!(input.is_open())) {
std::cerr << "Cannot open image: " << image_names[i] << std::endl;
FDASSERT(false, "Read file error.");
continue;
}
// Get the size
long unsigned int file_size = input.tellg();
input.seekg(0, std::ios::beg);
// resize if buffer is too small
if (raw_data[i].size() < file_size) {
raw_data[i].resize(file_size);
}
if (!input.read(raw_data[i].data(), file_size)) {
std::cerr << "Cannot read from file: " << image_names[i] << std::endl;
// image_names.erase(cur_iter);
FDASSERT(false, "Read file error.");
continue;
}
raw_len[i] = file_size;
}
return EXIT_SUCCESS;
}
// prepare buffers for RGBi output format
static int prepare_buffers(FileData& file_data, std::vector<size_t>& file_len,
std::vector<int>& img_width,
std::vector<int>& img_height,
std::vector<nvjpegImage_t>& ibuf,
std::vector<nvjpegImage_t>& isz,
std::vector<FDTensor*>& output_buffers,
const FileNames& current_names,
decode_params_t& params) {
int widths[NVJPEG_MAX_COMPONENT];
int heights[NVJPEG_MAX_COMPONENT];
int channels;
nvjpegChromaSubsampling_t subsampling;
for (long unsigned int i = 0; i < file_data.size(); i++) {
nvjpegStatus_t status = nvjpegGetImageInfo(
params.nvjpeg_handle, (unsigned char*)file_data[i].data(), file_len[i],
&channels, &subsampling, widths, heights);
if (status != NVJPEG_STATUS_SUCCESS) {
std::cout << "NVJPEG failure: #" << status << " in nvjpegGetImageInfo."
<< std::endl;
return EXIT_FAILURE;
}
img_width[i] = widths[0];
img_height[i] = heights[0];
int mul = 1;
// in the case of interleaved RGB output, write only to single channel, but
// 3 samples at once
if (params.fmt == NVJPEG_OUTPUT_RGBI || params.fmt == NVJPEG_OUTPUT_BGRI) {
channels = 1;
mul = 3;
} else if (params.fmt == NVJPEG_OUTPUT_RGB ||
params.fmt == NVJPEG_OUTPUT_BGR) {
// in the case of rgb create 3 buffers with sizes of original image
channels = 3;
widths[1] = widths[2] = widths[0];
heights[1] = heights[2] = heights[0];
} else {
FDASSERT(false, "Unsupport NVJPEG output format: %d", params.fmt);
}
output_buffers[i]->Resize({heights[0], widths[0], mul * channels},
FDDataType::UINT8, "output_cache", Device::GPU);
uint8_t* cur_buffer = reinterpret_cast<uint8_t*>(output_buffers[i]->Data());
// realloc output buffer if required
for (int c = 0; c < channels; c++) {
int aw = mul * widths[c];
int ah = heights[c];
size_t sz = aw * ah;
ibuf[i].pitch[c] = aw;
if (sz > isz[i].pitch[c]) {
ibuf[i].channel[c] = cur_buffer;
cur_buffer = cur_buffer + sz;
isz[i].pitch[c] = sz;
}
}
}
return EXIT_SUCCESS;
}
static void create_decoupled_api_handles(decode_params_t& params) {
CHECK_NVJPEG(nvjpegDecoderCreate(params.nvjpeg_handle, NVJPEG_BACKEND_DEFAULT,
&params.nvjpeg_decoder));
CHECK_NVJPEG(nvjpegDecoderStateCreate(params.nvjpeg_handle,
params.nvjpeg_decoder,
&params.nvjpeg_decoupled_state));
CHECK_NVJPEG(nvjpegBufferPinnedCreate(params.nvjpeg_handle, NULL,
&params.pinned_buffers[0]));
CHECK_NVJPEG(nvjpegBufferPinnedCreate(params.nvjpeg_handle, NULL,
&params.pinned_buffers[1]));
CHECK_NVJPEG(nvjpegBufferDeviceCreate(params.nvjpeg_handle, NULL,
&params.device_buffer));
CHECK_NVJPEG(
nvjpegJpegStreamCreate(params.nvjpeg_handle, &params.jpeg_streams[0]));
CHECK_NVJPEG(
nvjpegJpegStreamCreate(params.nvjpeg_handle, &params.jpeg_streams[1]));
CHECK_NVJPEG(nvjpegDecodeParamsCreate(params.nvjpeg_handle,
&params.nvjpeg_decode_params));
}
static void destroy_decoupled_api_handles(decode_params_t& params) {
CHECK_NVJPEG(nvjpegDecodeParamsDestroy(params.nvjpeg_decode_params));
CHECK_NVJPEG(nvjpegJpegStreamDestroy(params.jpeg_streams[0]));
CHECK_NVJPEG(nvjpegJpegStreamDestroy(params.jpeg_streams[1]));
CHECK_NVJPEG(nvjpegBufferPinnedDestroy(params.pinned_buffers[0]));
CHECK_NVJPEG(nvjpegBufferPinnedDestroy(params.pinned_buffers[1]));
CHECK_NVJPEG(nvjpegBufferDeviceDestroy(params.device_buffer));
CHECK_NVJPEG(nvjpegJpegStateDestroy(params.nvjpeg_decoupled_state));
CHECK_NVJPEG(nvjpegDecoderDestroy(params.nvjpeg_decoder));
}
int decode_images(const FileData& img_data, const std::vector<size_t>& img_len,
std::vector<nvjpegImage_t>& out, decode_params_t& params,
double& time) {
CHECK_CUDA(cudaStreamSynchronize(params.stream));
std::vector<const unsigned char*> batched_bitstreams;
std::vector<size_t> batched_bitstreams_size;
std::vector<nvjpegImage_t> batched_output;
// bit-streams that batched decode cannot handle
std::vector<const unsigned char*> otherdecode_bitstreams;
std::vector<size_t> otherdecode_bitstreams_size;
std::vector<nvjpegImage_t> otherdecode_output;
if (params.hw_decode_available) {
for (int i = 0; i < params.batch_size; i++) {
// extract bitstream meta data to figure out whether a bit-stream can be
// decoded
nvjpegJpegStreamParseHeader(params.nvjpeg_handle,
(const unsigned char*)img_data[i].data(),
img_len[i], params.jpeg_streams[0]);
int isSupported = -1;
nvjpegDecodeBatchedSupported(params.nvjpeg_handle, params.jpeg_streams[0],
&isSupported);
if (isSupported == 0) {
batched_bitstreams.push_back((const unsigned char*)img_data[i].data());
batched_bitstreams_size.push_back(img_len[i]);
batched_output.push_back(out[i]);
} else {
otherdecode_bitstreams.push_back(
(const unsigned char*)img_data[i].data());
otherdecode_bitstreams_size.push_back(img_len[i]);
otherdecode_output.push_back(out[i]);
}
}
} else {
for (int i = 0; i < params.batch_size; i++) {
otherdecode_bitstreams.push_back(
(const unsigned char*)img_data[i].data());
otherdecode_bitstreams_size.push_back(img_len[i]);
otherdecode_output.push_back(out[i]);
}
}
if (batched_bitstreams.size() > 0) {
CHECK_NVJPEG(nvjpegDecodeBatchedInitialize(
params.nvjpeg_handle, params.nvjpeg_state, batched_bitstreams.size(), 1,
params.fmt));
CHECK_NVJPEG(nvjpegDecodeBatched(
params.nvjpeg_handle, params.nvjpeg_state, batched_bitstreams.data(),
batched_bitstreams_size.data(), batched_output.data(), params.stream));
}
if (otherdecode_bitstreams.size() > 0) {
CHECK_NVJPEG(nvjpegStateAttachDeviceBuffer(params.nvjpeg_decoupled_state,
params.device_buffer));
int buffer_index = 0;
CHECK_NVJPEG(nvjpegDecodeParamsSetOutputFormat(params.nvjpeg_decode_params,
params.fmt));
for (int i = 0; i < params.batch_size; i++) {
CHECK_NVJPEG(nvjpegJpegStreamParse(params.nvjpeg_handle,
otherdecode_bitstreams[i],
otherdecode_bitstreams_size[i], 0, 0,
params.jpeg_streams[buffer_index]));
CHECK_NVJPEG(nvjpegStateAttachPinnedBuffer(
params.nvjpeg_decoupled_state, params.pinned_buffers[buffer_index]));
CHECK_NVJPEG(nvjpegDecodeJpegHost(
params.nvjpeg_handle, params.nvjpeg_decoder,
params.nvjpeg_decoupled_state, params.nvjpeg_decode_params,
params.jpeg_streams[buffer_index]));
CHECK_CUDA(cudaStreamSynchronize(params.stream));
CHECK_NVJPEG(nvjpegDecodeJpegTransferToDevice(
params.nvjpeg_handle, params.nvjpeg_decoder,
params.nvjpeg_decoupled_state, params.jpeg_streams[buffer_index],
params.stream));
buffer_index = 1 - buffer_index; // switch pinned buffer in pipeline mode
// to avoid an extra sync
CHECK_NVJPEG(
nvjpegDecodeJpegDevice(params.nvjpeg_handle, params.nvjpeg_decoder,
params.nvjpeg_decoupled_state,
&otherdecode_output[i], params.stream));
}
}
return EXIT_SUCCESS;
}
double process_images(const FileNames& image_names, decode_params_t& params,
double& total, std::vector<nvjpegImage_t>& iout,
std::vector<FDTensor*>& output_buffers,
std::vector<int>& widths, std::vector<int>& heights) {
FDASSERT(image_names.size() == params.batch_size,
"Number of images and batch size must be equal.");
// vector for storing raw files and file lengths
FileData file_data(params.batch_size);
std::vector<size_t> file_len(params.batch_size);
FileNames current_names(params.batch_size);
// we wrap over image files to process total_images of files
auto file_iter = image_names.begin();
// output buffer sizes, for convenience
std::vector<nvjpegImage_t> isz(params.batch_size);
for (long unsigned int i = 0; i < iout.size(); i++) {
for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) {
iout[i].channel[c] = NULL;
iout[i].pitch[c] = 0;
isz[i].pitch[c] = 0;
}
}
if (read_images(image_names, file_data, file_len)) {
return EXIT_FAILURE;
}
if (prepare_buffers(file_data, file_len, widths, heights, iout, isz,
output_buffers, image_names, params)) {
return EXIT_FAILURE;
}
double time;
if (decode_images(file_data, file_len, iout, params, time)) {
return EXIT_FAILURE;
}
return EXIT_SUCCESS;
}
void init_decoder(decode_params_t& params) {
params.hw_decode_available = true;
nvjpegDevAllocator_t dev_allocator = {&dev_malloc, &dev_free};
nvjpegPinnedAllocator_t pinned_allocator = {&host_malloc, &host_free};
nvjpegStatus_t status =
nvjpegCreateEx(NVJPEG_BACKEND_HARDWARE, &dev_allocator, &pinned_allocator,
NVJPEG_FLAGS_DEFAULT, &params.nvjpeg_handle);
if (status == NVJPEG_STATUS_ARCH_MISMATCH) {
std::cout << "Hardware Decoder not supported. "
"Falling back to default backend"
<< std::endl;
CHECK_NVJPEG(nvjpegCreateEx(NVJPEG_BACKEND_DEFAULT, &dev_allocator,
&pinned_allocator, NVJPEG_FLAGS_DEFAULT,
&params.nvjpeg_handle));
params.hw_decode_available = false;
} else {
CHECK_NVJPEG(status);
}
CHECK_NVJPEG(
nvjpegJpegStateCreate(params.nvjpeg_handle, &params.nvjpeg_state));
create_decoupled_api_handles(params);
}
void destroy_decoder(decode_params_t& params) {
destroy_decoupled_api_handles(params);
CHECK_NVJPEG(nvjpegJpegStateDestroy(params.nvjpeg_state));
CHECK_NVJPEG(nvjpegDestroy(params.nvjpeg_handle));
}
} // namespace nvjpeg
} // namespace vision
} // namespace fastdeploy
#endif // ENABLE_NVJPEG
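Taken together, the lifecycle this file exposes is: init_decoder once, process_images per batch, destroy_decoder at the end — which is how ImageDecoder::ImplByNvJpeg drives it. A condensed sketch of that sequence (error handling omitted; assumes output_buffers point at FDMat output caches as in image_decoder.cc):

#ifdef ENABLE_NVJPEG
#include <string>
#include <vector>
#include "fastdeploy/vision/common/image_decoder/nvjpeg_decoder.h"

void DecodeBatchOnce(const std::vector<std::string>& names,
                     std::vector<fastdeploy::FDTensor*>& output_buffers,
                     cudaStream_t stream) {
  namespace nv = fastdeploy::vision::nvjpeg;
  nv::decode_params_t params;
  nv::init_decoder(params);         // create handles, prefer the HW backend
  params.batch_size = static_cast<int>(names.size());
  params.fmt = NVJPEG_OUTPUT_BGRI;  // interleaved BGR, as used in this PR
  params.stream = stream;

  std::vector<nvjpegImage_t> out(names.size());
  std::vector<int> widths(names.size()), heights(names.size());
  double total = 0.0;
  nv::process_images(names, params, total, out, output_buffers, widths,
                     heights);
  nv::destroy_decoder(params);      // tear down all nvJPEG handles
}
#endif  // ENABLE_NVJPEG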

fastdeploy/vision/common/image_decoder/nvjpeg_decoder.h

@@ -0,0 +1,69 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Part of the following code in this file refers to
// https://github.com/CVCUDA/CV-CUDA/blob/release_v0.2.x/samples/common/NvDecoder.h
//
// Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Licensed under the Apache-2.0 license
// \brief
// \author NVIDIA
#pragma once
#ifdef ENABLE_NVJPEG
#include "fastdeploy/core/fd_tensor.h"
#include <cuda_runtime_api.h>
#include <nvjpeg.h>
namespace fastdeploy {
namespace vision {
namespace nvjpeg {
typedef std::vector<std::string> FileNames;
typedef std::vector<std::vector<char>> FileData;
struct decode_params_t {
int batch_size;
nvjpegJpegState_t nvjpeg_state;
nvjpegHandle_t nvjpeg_handle;
cudaStream_t stream;
// used with decoupled API
nvjpegJpegState_t nvjpeg_decoupled_state;
nvjpegBufferPinned_t pinned_buffers[2]; // 2 buffers for pipelining
nvjpegBufferDevice_t device_buffer;
nvjpegJpegStream_t jpeg_streams[2]; // 2 streams for pipelining
nvjpegDecodeParams_t nvjpeg_decode_params;
nvjpegJpegDecoder_t nvjpeg_decoder;
nvjpegOutputFormat_t fmt;
bool hw_decode_available;
};
void init_decoder(decode_params_t& params);
void destroy_decoder(decode_params_t& params);
double process_images(const FileNames& image_names, decode_params_t& params,
double& total, std::vector<nvjpegImage_t>& iout,
std::vector<FDTensor*>& output_buffers,
std::vector<int>& widths, std::vector<int>& heights);
} // namespace nvjpeg
} // namespace vision
} // namespace fastdeploy
#endif // ENABLE_NVJPEG

fastdeploy/vision/common/processors/manager.cc

@@ -77,6 +77,16 @@ bool ProcessorManager::Run(std::vector<FDMat>* images,
     }
     (*images)[i].input_cache = &input_caches_[i];
     (*images)[i].output_cache = &output_caches_[i];
+    if ((*images)[i].mat_type == ProcLib::CUDA) {
+      // Make a copy of the input data ptr, so that the original data ptr of
+      // FDMat won't be modified.
+      auto fd_tensor = std::make_shared<FDTensor>();
+      fd_tensor->SetExternalData(
+          (*images)[i].Tensor()->shape, (*images)[i].Tensor()->Dtype(),
+          (*images)[i].Tensor()->Data(), (*images)[i].Tensor()->device,
+          (*images)[i].Tensor()->device_id);
+      (*images)[i].SetTensor(fd_tensor);
+    }
   }
 
   bool ret = Apply(&image_batch, outputs);
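The new ProcLib::CUDA branch copies only the tensor descriptor: SetExternalData builds a fresh non-owning FDTensor over the same device pointer, so when a downstream processor later repoints the mat via SetTensor, the caller's original tensor is left untouched. The idea in miniature (hypothetical TensorView/MakeView names, not FastDeploy API):

#include <cstdint>
#include <memory>
#include <vector>

// Non-owning view: carries shape plus a borrowed pointer it never frees.
struct TensorView {
  std::vector<std::int64_t> shape;
  void* data = nullptr;
};

std::shared_ptr<TensorView> MakeView(const TensorView& src) {
  auto view = std::make_shared<TensorView>();
  view->shape = src.shape;  // metadata is duplicated
  view->data = src.data;    // the buffer is shared, not copied
  return view;              // repointing the consumer leaves src intact
}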

fastdeploy/vision/common/processors/manager.h

@@ -35,6 +35,10 @@ class FASTDEPLOY_DECL ProcessorManager {
   bool CudaUsed();
 
+#ifdef WITH_GPU
+  cudaStream_t Stream() const { return stream_; }
+#endif
+
   void SetStream(FDMat* mat) {
 #ifdef WITH_GPU
     mat->SetStream(stream_);
@@ -56,7 +60,7 @@ class FASTDEPLOY_DECL ProcessorManager {
   int DeviceId() { return device_id_; }
 
-  /** \brief Process the input image and prepare input tensors for runtime
+  /** \brief Process the input images and prepare input tensors for runtime
   *
   * \param[in] images The input image data list, all the elements are returned by cv::imread()
   * \param[in] outputs The output tensors which will feed in runtime

fastdeploy/vision/common/processors/mat.cc

@@ -37,7 +37,7 @@ cv::Mat* Mat::GetOpenCVMat() {
 #ifdef WITH_GPU
     FDASSERT(cudaStreamSynchronize(stream) == cudaSuccess,
              "[ERROR] Error occurs while sync cuda stream.");
-    cpu_mat = CreateZeroCopyOpenCVMatFromTensor(fd_tensor);
+    cpu_mat = CreateZeroCopyOpenCVMatFromTensor(*fd_tensor);
     mat_type = ProcLib::OPENCV;
     device = Device::CPU;
     return &cpu_mat;
@@ -59,29 +59,53 @@ void* Mat::Data() {
                     "fcv::Mat.");
 #endif
   } else if (device == Device::GPU) {
-    return fd_tensor.Data();
+    return fd_tensor->Data();
   }
   return cpu_mat.ptr();
 }
 
 FDTensor* Mat::Tensor() {
   if (mat_type == ProcLib::OPENCV) {
-    ShareWithTensor(&fd_tensor);
+    ShareWithTensor(fd_tensor.get());
   } else if (mat_type == ProcLib::FLYCV) {
 #ifdef ENABLE_FLYCV
     cpu_mat = ConvertFlyCVMatToOpenCV(fcv_mat);
     mat_type = ProcLib::OPENCV;
-    ShareWithTensor(&fd_tensor);
+    ShareWithTensor(fd_tensor.get());
 #else
     FDASSERT(false, "FastDeploy didn't compiled with FlyCV!");
 #endif
   }
-  return &fd_tensor;
+  return fd_tensor.get();
 }
 
 void Mat::SetTensor(FDTensor* tensor) {
-  fd_tensor.SetExternalData(tensor->Shape(), tensor->Dtype(), tensor->Data(),
-                            tensor->device, tensor->device_id);
+  fd_tensor->SetExternalData(tensor->Shape(), tensor->Dtype(), tensor->Data(),
+                             tensor->device, tensor->device_id);
+  device = tensor->device;
+  if (layout == Layout::HWC) {
+    height = tensor->Shape()[0];
+    width = tensor->Shape()[1];
+    channels = tensor->Shape()[2];
+  } else if (layout == Layout::CHW) {
+    channels = tensor->Shape()[0];
+    height = tensor->Shape()[1];
+    width = tensor->Shape()[2];
+  }
+}
+
+void Mat::SetTensor(std::shared_ptr<FDTensor>& tensor) {
+  fd_tensor = tensor;
+  device = tensor->device;
+  if (layout == Layout::HWC) {
+    height = tensor->Shape()[0];
+    width = tensor->Shape()[1];
+    channels = tensor->Shape()[2];
+  } else if (layout == Layout::CHW) {
+    channels = tensor->Shape()[0];
+    height = tensor->Shape()[1];
+    width = tensor->Shape()[2];
+  }
 }
 
 void Mat::ShareWithTensor(FDTensor* tensor) {
@@ -134,7 +158,7 @@ void Mat::PrintInfo(const std::string& flag) {
 #ifdef WITH_GPU
     FDASSERT(cudaStreamSynchronize(stream) == cudaSuccess,
              "[ERROR] Error occurs while sync cuda stream.");
-    cv::Mat tmp_mat = CreateZeroCopyOpenCVMatFromTensor(fd_tensor);
+    cv::Mat tmp_mat = CreateZeroCopyOpenCVMatFromTensor(*fd_tensor);
     cv::Scalar mean = cv::mean(tmp_mat);
     for (int i = 0; i < Channels(); ++i) {
       std::cout << mean[i] << " ";
@@ -157,7 +181,7 @@ FDDataType Mat::Type() {
                     "fcv::Mat.");
 #endif
   } else if (mat_type == ProcLib::CUDA || mat_type == ProcLib::CVCUDA) {
-    return fd_tensor.Dtype();
+    return fd_tensor->Dtype();
   }
   return OpenCVDataTypeToFD(cpu_mat.type());
 }
@@ -262,6 +286,10 @@ FDTensor* CreateCachedGpuInputTensor(Mat* mat) {
 #ifdef WITH_GPU
   FDTensor* src = mat->Tensor();
   if (src->device == Device::GPU) {
+    if (src->Data() == mat->output_cache->Data()) {
+      std::swap(mat->input_cache, mat->output_cache);
+      std::swap(mat->input_cache->name, mat->output_cache->name);
+    }
     return src;
   } else if (src->device == Device::CPU) {
     // Mats on CPU, we need copy these tensors from CPU to GPU
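The swap added to CreateCachedGpuInputTensor appears to cover the case where the previous processor left its result in output_cache: if the tensor about to be consumed aliases output_cache, the two caches trade roles so the next stage does not write over the buffer it is reading. The ping-pong idea in isolation (assumed semantics, not the actual processor code):

#include <utility>

struct Caches {
  float* input_cache;
  float* output_cache;
};

// If the data to be consumed lives in output_cache, swap the cache roles so
// output_cache becomes free scratch space for the next processor.
inline float* PrepareGpuInput(Caches* c, float* src) {
  if (src == c->output_cache) {
    std::swap(c->input_cache, c->output_cache);
  }
  return src;  // aliases input_cache whenever a swap occurred
}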

fastdeploy/vision/common/processors/mat.h

@@ -49,7 +49,6 @@ struct FASTDEPLOY_DECL Mat {
 #endif
 
   Mat(const Mat& mat) = default;
-  // Move assignment
   Mat& operator=(const Mat& mat) = default;
 
   // Move constructor
@@ -96,6 +95,8 @@ struct FASTDEPLOY_DECL Mat {
   // Set fd_tensor
   void SetTensor(FDTensor* tensor);
 
+  void SetTensor(std::shared_ptr<FDTensor>& tensor);
+
 private:
   int channels;
   int height;
@@ -109,7 +110,7 @@ struct FASTDEPLOY_DECL Mat {
 #endif
   // Currently, fd_tensor is only used by CUDA and CV-CUDA,
   // OpenCV and FlyCV are not using it.
-  FDTensor fd_tensor;
+  std::shared_ptr<FDTensor> fd_tensor = std::make_shared<FDTensor>();
 
 public:
   FDDataType Type();

fastdeploy/vision/common/processors/mat_batch.cc

@@ -27,7 +27,7 @@ void FDMatBatch::SetStream(cudaStream_t s) {
 FDTensor* FDMatBatch::Tensor() {
   if (has_batched_tensor) {
-    return &fd_tensor;
+    return fd_tensor.get();
   }
   FDASSERT(CheckShapeConsistency(mats), "Mats shapes are not consistent.")
   // Each mat has its own tensor,
@@ -45,12 +45,12 @@ FDTensor* FDMatBatch::Tensor() {
                             num_bytes, device, false);
   }
   SetTensor(input_cache);
-  return &fd_tensor;
+  return fd_tensor.get();
 }
 
 void FDMatBatch::SetTensor(FDTensor* tensor) {
-  fd_tensor.SetExternalData(tensor->Shape(), tensor->Dtype(), tensor->Data(),
-                            tensor->device, tensor->device_id);
+  fd_tensor->SetExternalData(tensor->Shape(), tensor->Dtype(), tensor->Data(),
+                             tensor->device, tensor->device_id);
   has_batched_tensor = true;
 }

fastdeploy/vision/common/processors/mat_batch.h

@@ -29,7 +29,7 @@ struct FASTDEPLOY_DECL FDMatBatch {
   // MatBatch is intialized with a list of mats,
   // the data is stored in the mats separately.
   // Call Tensor() function to get a batched 4-dimension tensor.
-  explicit FDMatBatch(std::vector<Mat>* _mats) {
+  explicit FDMatBatch(std::vector<FDMat>* _mats) {
     mats = _mats;
     layout = FDMatBatchLayout::NHWC;
     mat_type = ProcLib::OPENCV;
@@ -44,7 +44,7 @@ struct FASTDEPLOY_DECL FDMatBatch {
 #ifdef WITH_GPU
   cudaStream_t stream = nullptr;
 #endif
-  FDTensor fd_tensor;
+  std::shared_ptr<FDTensor> fd_tensor = std::make_shared<FDTensor>();
 
 public:
   // When using CV-CUDA/CUDA, please set input/output cache,

fastdeploy/vision/common/processors/normalize_and_permute.cu

@@ -81,7 +81,7 @@ bool NormalizeAndPermute::ImplByCuda(FDMatBatch* mat_batch) {
   // Prepare output tensor
   mat_batch->output_cache->Resize(src->Shape(), FDDataType::FP32,
-                                  "output_cache", Device::GPU);
+                                  "batch_output_cache", Device::GPU);
   // NHWC -> NCHW
   std::swap(mat_batch->output_cache->shape[1],
             mat_batch->output_cache->shape[3]);

tutorials/README.md

@@ -1,13 +1,9 @@
English | [中文](README_CN.md) English | [中文](README_CN.md)
# Tutorials # Tutorials
This directory provides some tutorials for FastDeploy. For other model deployment, please refer to the example [FastDeploy/examples](../examples) directly. This directory provides some tutorials for FastDeploy. For other model deployment, please refer to the example [FastDeploy/examples](../examples) directly.
- Intel independent graphics card/integrated graphics card deployment [see intel_gpu](intel_gpu) - Intel independent graphics card/integrated graphics card deployment [see intel_gpu](intel_gpu)
- Model multithreaded call [see multi_thread](multi_thread) - Model multithreaded call [see multi_thread](multi_thread)
- Image decoding, including hardward decoding, e.g. nvJPEG [image_decoder](image_decoder)

tutorials/README_CN.md

@@ -7,3 +7,4 @@
 - Intel discrete/integrated graphics card deployment [see intel_gpu](intel_gpu)
 - Multithreaded model invocation [see multi_thread](multi_thread)
+- Image decoding, including nvJPEG hardware decoding [see image_decoder](image_decoder)

tutorials/image_decoder/README.md

@@ -0,0 +1,16 @@
English | [中文](README_CN.md)

# Image Decoder

Currently, the following image decoder libraries are supported:
- OpenCV
- nvJPEG (requires an NVIDIA GPU; not supported on Jetson)

## Example
- [C++ Example](cpp)
- Python API (WIP)

## nvJPEG vs. OpenCV performance benchmark

Refer to: https://github.com/PaddlePaddle/FastDeploy/pull/1288#issuecomment-1427749772
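A minimal single-image usage sketch, condensed from the C++ example in this directory (OpenCV backend by default; pass `ImageDecoderLib::NVJPEG` to select nvJPEG):

```cpp
#include "fastdeploy/vision/common/image_decoder/image_decoder.h"

int main() {
  fastdeploy::vision::FDMat mat;
  fastdeploy::vision::ImageDecoder decoder;  // defaults to ImageDecoderLib::OPENCV
  decoder.Decode("test.jpeg", &mat);         // single-image API
  mat.PrintInfo("");
  return 0;
}
```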

tutorials/image_decoder/README_CN.md

@@ -0,0 +1,16 @@
简体中文 | [English](README.md)

# Image Decoder

Image decoding library. The following decoder libraries are currently supported:
- OpenCV
- nvJPEG (requires an NVIDIA GPU; not supported on Jetson)

## Example Code
- [C++ Example](cpp)
- The Python API is still under development...

## nvJPEG vs. OpenCV performance comparison

See https://github.com/PaddlePaddle/FastDeploy/pull/1288#issuecomment-1427749772

tutorials/image_decoder/cpp/CMakeLists.txt

@@ -0,0 +1,11 @@
PROJECT(image_decoder C CXX)
CMAKE_MINIMUM_REQUIRED (VERSION 3.10)
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
include_directories(${FASTDEPLOY_INCS})
add_executable(image_decoder ${PROJECT_SOURCE_DIR}/main.cc)
target_link_libraries(image_decoder ${FASTDEPLOY_LIBS})

tutorials/image_decoder/cpp/README.md

@@ -0,0 +1,22 @@
English | [中文](README_CN.md)

# Image Decoder C++ Example

1. [Build FastDeploy](../docs/cn/build_and_install) or download the [FastDeploy prebuilt library](../docs/cn/build_and_install/download_prebuilt_libraries.md)
2. Build the example
```bash
mkdir build
cd build
# [PATH-TO-FASTDEPLOY] is the install directory of FastDeploy
cmake .. -DFASTDEPLOY_INSTALL_DIR=[PATH-TO-FASTDEPLOY]
make -j

# Download the test image
wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg

# OpenCV decoder
./image_decoder ILSVRC2012_val_00000010.jpeg 0

# nvJPEG decoder
./image_decoder ILSVRC2012_val_00000010.jpeg 1
```

tutorials/image_decoder/cpp/README_CN.md

@@ -0,0 +1,22 @@
简体中文 | [English](README.md)

# Image Decoder C++ Example

1. [Build FastDeploy](../docs/cn/build_and_install), or directly download the [FastDeploy prebuilt library](../docs/cn/build_and_install/download_prebuilt_libraries.md)
2. Build the example
```bash
mkdir build
cd build
# Replace [PATH-TO-FASTDEPLOY] with the FastDeploy install directory
cmake .. -DFASTDEPLOY_INSTALL_DIR=[PATH-TO-FASTDEPLOY]
make -j

# Download the test image
wget https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg

# OpenCV decoder
./image_decoder ILSVRC2012_val_00000010.jpeg 0

# nvJPEG decoder
./image_decoder ILSVRC2012_val_00000010.jpeg 1
```

tutorials/image_decoder/cpp/main.cc

@@ -0,0 +1,57 @@
#include "fastdeploy/vision/common/image_decoder/image_decoder.h"
namespace fdvis = fastdeploy::vision;
namespace fd = fastdeploy;
void OpenCVImageDecode(const std::string& img_name) {
fdvis::FDMat mat;
auto img_decoder = new fdvis::ImageDecoder();
img_decoder->Decode(img_name, &mat);
mat.PrintInfo("");
delete img_decoder;
}
void NvJpegImageDecode(const std::string& img_name) {
std::vector<fdvis::FDMat> mats(1);
std::vector<fastdeploy::FDTensor> caches(1);
cudaStream_t stream;
cudaStreamCreate(&stream);
// For the nvJPEG decoder, we need to set the stream and output cache for the FDMat
for (size_t i = 0; i < mats.size(); i++) {
mats[i].output_cache = &caches[i];
mats[i].SetStream(stream);
}
auto img_decoder = new fdvis::ImageDecoder(fdvis::ImageDecoderLib::NVJPEG);
// This is the batch decode API; for the single-image decode API,
// please refer to OpenCVImageDecode()
img_decoder->BatchDecode({img_name}, &mats);
for (size_t i = 0; i < mats.size(); i++) {
std::cout << "Mat type: " << mats[i].mat_type << ", "
<< "DataType=" << mats[i].Type() << ", "
<< "Channel=" << mats[i].Channels() << ", "
<< "Height=" << mats[i].Height() << ", "
<< "Width=" << mats[i].Width() << std::endl;
}
cudaStreamDestroy(stream);
}
int main(int argc, char* argv[]) {
if (argc < 3) {
std::cout << "Usage: image_decoder path/to/image run_option, "
"e.g ./image_decoder ./test.jpeg 0"
<< std::endl;
std::cout << "Run_option 0: OpenCV; 1: nvJPEG " << std::endl;
return -1;
}
if (std::atoi(argv[2]) == 0) {
OpenCVImageDecode(argv[1]);
} else if (std::atoi(argv[2]) == 1) {
NvJpegImageDecode(argv[1]);
}
return 0;
}