Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-07 09:31:35 +08:00
[Other] Optimize code style (#1032)

* Optimize code
* optimize code
* optimize code
* fix compile error
@@ -19,7 +19,6 @@
 #include <string>
 #include <vector>
 
-#include "fastdeploy/backends/common/multiclass_nms.h"
 #include "fastdeploy/core/fd_tensor.h"
 #include "fastdeploy/core/fd_type.h"
 
@@ -2,7 +2,9 @@
 
 namespace fastdeploy {
 
-__global__ void CudaCastKernel(const float* in, float* out, int edge, int out_bc_offset, int in_bc_offset, int ih, int iw, int oh, int ow, bool is_avg) {
+__global__ void CudaCastKernel(const float* in, float* out, int edge,
+int out_bc_offset, int in_bc_offset, int ih,
+int iw, int oh, int ow, bool is_avg) {
 int position = blockDim.x * blockIdx.x + threadIdx.x;
 if (position >= edge) {
 return;
@@ -14,38 +16,41 @@ __global__ void CudaCastKernel(const float* in, float* out, int edge, int out_b
 int hend = ceilf(static_cast<float>((h + 1) * ih) / oh);
 int wstart = floorf(static_cast<float>(w * iw) / ow);
 int wend = ceilf(static_cast<float>((w + 1) * iw) / ow);
-if(is_avg) {
+if (is_avg) {
 out[position] = 0.0;
 } else {
 out[position] = in[offset * in_bc_offset + hstart * iw + wstart];
 }
 for (int h = hstart; h < hend; ++h) {
 for (int w = wstart; w < wend; ++w) {
 int input_idx = h * iw + w;
-if(is_avg) {
+if (is_avg) {
 out[position] = out[position] + in[offset * in_bc_offset + input_idx];
 } else {
-out[position] = max(out[position], in[offset * in_bc_offset + input_idx]);
+out[position] =
+max(out[position], in[offset * in_bc_offset + input_idx]);
 }
 }
 }
 out[position] = out[position] / ((hend - hstart) * (wend - wstart));
 }
 
-void CudaAdaptivePool(const std::vector<int64_t>& input_dims, const std::vector<int64_t>& output_dims, float* output, const float* input, void* compute_stream, const std::string& pooling_type){
+void CudaAdaptivePool(const std::vector<int64_t>& input_dims,
+const std::vector<int64_t>& output_dims, float* output,
+const float* input, void* compute_stream,
+const std::string& pooling_type) {
 auto casted_compute_stream = reinterpret_cast<cudaStream_t>(compute_stream);
 int out_bc_offset = output_dims[2] * output_dims[3];
 int in_bc_offset = input_dims[2] * input_dims[3];
 int jobs = 1;
-for(int i : output_dims) {
+for (int i : output_dims) {
 jobs *= i;
 }
 bool is_avg = pooling_type == "avg";
 int threads = 256;
 int blocks = ceil(jobs / static_cast<float>(threads));
 CudaCastKernel<<<blocks, threads, 0, casted_compute_stream>>>(
-input,
-output,
-jobs, out_bc_offset, in_bc_offset, int(input_dims[2]), int(input_dims[3]), int(output_dims[2]), int(output_dims[3]), is_avg);
+input, output, jobs, out_bc_offset, in_bc_offset, int(input_dims[2]),
+int(input_dims[3]), int(output_dims[2]), int(output_dims[3]), is_avg);
 }
 } // namespace fastdeploy
@@ -15,21 +15,18 @@
 
 #pragma once
 
+#include <cstdint>
 #include <cuda.h>
 #include <cuda_runtime.h>
-#include <cstdint>
 #include <iostream>
-#include <vector>
 #include <math.h>
+#include <vector>
 
 namespace fastdeploy {
 
 void CudaAdaptivePool(const std::vector<int64_t>& input_dims,
-const std::vector<int64_t>& output_dims,
-float* output,
-const float* input,
-void* compute_stream,
+const std::vector<int64_t>& output_dims, float* output,
+const float* input, void* compute_stream,
 const std::string& pooling_type);
 
-
 } // namespace fastdeploy
@@ -341,8 +341,7 @@ int OpenVINOBackend::NumInputs() const { return input_infos_.size(); }
 int OpenVINOBackend::NumOutputs() const { return output_infos_.size(); }
 
 bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs,
-std::vector<FDTensor>* outputs,
-bool copy_to_fd) {
+std::vector<FDTensor>* outputs, bool copy_to_fd) {
 if (inputs.size() != input_infos_.size()) {
 FDERROR << "[OpenVINOBackend] Size of the inputs(" << inputs.size()
 << ") should keep same with the inputs of this model(
@@ -365,19 +364,17 @@ bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs,
 auto out_tensor_shape = out_tensor.get_shape();
 std::vector<int64_t> shape(out_tensor_shape.begin(),
 out_tensor_shape.end());
-if(copy_to_fd) {
+if (copy_to_fd) {
 (*outputs)[i].Resize(shape,
 OpenVINODataTypeToFD(out_tensor.get_element_type()),
-output_infos_[i].name,
-Device::CPU);
+output_infos_[i].name, Device::CPU);
 memcpy((*outputs)[i].MutableData(), out_tensor.data(),
 (*outputs)[i].Nbytes());
 } else {
 (*outputs)[i].name = output_infos_[i].name;
-(*outputs)[i].SetExternalData(shape,
-OpenVINODataTypeToFD(out_tensor.get_element_type()),
-out_tensor.data(),
-Device::CPU);
+(*outputs)[i].SetExternalData(
+shape, OpenVINODataTypeToFD(out_tensor.get_element_type()),
+out_tensor.data(), Device::CPU);
 }
 }
 return true;
@@ -47,8 +47,7 @@ class OpenVINOBackend : public BaseBackend {
 InitFromOnnx(const std::string& model_file,
 const OpenVINOBackendOption& option = OpenVINOBackendOption());
 
-bool Infer(std::vector<FDTensor>& inputs,
-std::vector<FDTensor>* outputs,
+bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs,
 bool copy_to_fd = true) override;
 
 int NumInputs() const override;
@@ -25,30 +25,38 @@ struct OrtTensorDimensions : std::vector<int64_t> {
 }
 };
 
-void AdaptivePool2dKernel::CpuAdaptivePool(const std::vector<int64_t>& input_size, const std::vector<int64_t>& output_size, const float* input_data, float* output_data){
+void AdaptivePool2dKernel::CpuAdaptivePool(
+const std::vector<int64_t>& input_size,
+const std::vector<int64_t>& output_size, const float* input_data,
+float* output_data) {
 int64_t in_bc_offset = input_size[2] * input_size[3];
 int64_t out_bc_offset = output_size[2] * output_size[3];
-for (int64_t b = 0; b < output_size[0] ; b++) {
-for (int64_t c = 0; c < output_size[1] ; c++) {
-for(int64_t h = 0; h < output_size[2]; h++){
-int64_t hstart = std::floor( static_cast<float>(h * input_size[2]) / output_size[2]);
-int64_t hend = std::ceil(static_cast<float>((h + 1) * input_size[2]) / output_size[2]);
-for(int64_t w = 0; w < output_size[3]; w++){
-int64_t wstart = std::floor(static_cast<float>(w * input_size[3]) / output_size[3]);
-int64_t wend = std::ceil(static_cast<float>((w + 1) * input_size[3]) / output_size[3]);
+for (int64_t b = 0; b < output_size[0]; b++) {
+for (int64_t c = 0; c < output_size[1]; c++) {
+for (int64_t h = 0; h < output_size[2]; h++) {
+int64_t hstart =
+std::floor(static_cast<float>(h * input_size[2]) / output_size[2]);
+int64_t hend = std::ceil(static_cast<float>((h + 1) * input_size[2]) /
+output_size[2]);
+for (int64_t w = 0; w < output_size[3]; w++) {
+int64_t wstart = std::floor(static_cast<float>(w * input_size[3]) /
+output_size[3]);
+int64_t wend = std::ceil(static_cast<float>((w + 1) * input_size[3]) /
+output_size[3]);
 int64_t out_offset = h * output_size[3] + w;
 output_data[out_offset] = 0;
-for(auto i = hstart; i < hend; i++){
-for(auto j = wstart; j< wend; j++){
-if(pooling_type_ == "avg"){
+for (auto i = hstart; i < hend; i++) {
+for (auto j = wstart; j < wend; j++) {
+if (pooling_type_ == "avg") {
 output_data[out_offset] += input_data[i * input_size[3] + j];
 }
-if(pooling_type_ == "max"){
-output_data[out_offset] = std::max(output_data[out_offset], input_data[i * input_size[3] + j]);
+if (pooling_type_ == "max") {
+output_data[out_offset] = std::max(
+output_data[out_offset], input_data[i * input_size[3] + j]);
 }
 }
 }
-if(pooling_type_ == "avg"){
+if (pooling_type_ == "avg") {
 output_data[out_offset] /= ((hend - hstart) * (wend - wstart));
 }
 }
@@ -64,26 +72,27 @@ void AdaptivePool2dKernel::Compute(OrtKernelContext* context) {
 
 const float* input_data =
 reinterpret_cast<const float*>(ort_.GetTensorData<float>(input));
 
 OrtTensorDimensions input_dim(ort_, input);
 output_size_[0] = input_dim[0];
 std::vector<int64_t> input_size;
-for(auto i: input_dim){
+for (auto i : input_dim) {
 input_size.push_back(i);
 }
 
 OrtValue* output = ort_.KernelContext_GetOutput(
 context, 0, output_size_.data(), output_size_.size());
 
 float* output_data = ort_.GetTensorMutableData<float>(output);
-if(!strcmp(this->provider_, "CUDAExecutionProvider")){
+if (!strcmp(this->provider_, "CUDAExecutionProvider")) {
 #ifdef WITH_GPU
 auto compute_stream = ort_.KernelContext_GetGPUComputeStream(context);
-CudaAdaptivePool(input_size, output_size_, output_data, input_data, compute_stream, pooling_type_);
+CudaAdaptivePool(input_size, output_size_, output_data, input_data,
+compute_stream, pooling_type_);
 #else
 FDWARNING << "FastDeploy didn't compile with WITH_GPU. "
 << "Will force to use CPU to run." << std::endl;
 CpuAdaptivePool(input_size, output_size_, input_data, output_data);
 #endif
 } else {
 CpuAdaptivePool(input_size, output_size_, input_data, output_data);
@@ -91,9 +100,13 @@ void AdaptivePool2dKernel::Compute(OrtKernelContext* context) {
 }
 
 void AdaptivePool2dKernel::GetAttribute(const OrtKernelInfo* info) {
-pooling_type_ = ort_.KernelInfoGetAttribute<std::string>(info, "pooling_type");
-output_size_ = ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "output_size");
-FDASSERT(output_size_.size() == 4 && output_size_[2] > 0 && output_size_[3] > 0, "The output size of adaptive pool must be positive.");
+pooling_type_ =
+ort_.KernelInfoGetAttribute<std::string>(info, "pooling_type");
+output_size_ =
+ort_.KernelInfoGetAttribute<std::vector<int64_t>>(info, "output_size");
+FDASSERT(output_size_.size() == 4 && output_size_[2] > 0 &&
+output_size_[3] > 0,
+"The output size of adaptive pool must be positive.");
 }
 } // namespace fastdeploy
 
@@ -14,12 +14,12 @@
 
 #pragma once
 
-#include <map>
-#include <string>
-#include <algorithm>
-#include <cmath>
 #include "fastdeploy/core/fd_tensor.h"
 #include "fastdeploy/utils/utils.h"
+#include <algorithm>
+#include <cmath>
+#include <map>
+#include <string>
 
 #ifndef NON_64_PLATFORM
 #include "onnxruntime_cxx_api.h" // NOLINT
@@ -38,9 +38,8 @@ struct AdaptivePool2dKernel {
 const char* provider_;
 
 public:
-AdaptivePool2dKernel(Ort::CustomOpApi ort,
-const OrtKernelInfo* info,
-const char* provider)
+AdaptivePool2dKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info,
+const char* provider)
 : ort_(ort) {
 GetAttribute(info);
 provider_ = provider;
@@ -51,9 +50,8 @@ struct AdaptivePool2dKernel {
 void Compute(OrtKernelContext* context);
 
 void CpuAdaptivePool(const std::vector<int64_t>& input_size,
 const std::vector<int64_t>& output_size,
-const float* input_data,
-float* output_data);
+const float* input_data, float* output_data);
 };
 
 struct AdaptivePool2dOp
@@ -77,9 +75,8 @@ struct AdaptivePool2dOp
 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
 }
 
-const char* GetExecutionProviderType() const {
-return provider_;
-}
+const char* GetExecutionProviderType() const { return provider_; }
+
 private:
 const char* provider_;
 };
@@ -15,9 +15,9 @@
 #ifndef NON_64_PLATFORM
 
 #include "fastdeploy/backends/ort/ops/multiclass_nms.h"
-#include <algorithm>
 #include "fastdeploy/core/fd_tensor.h"
 #include "fastdeploy/utils/utils.h"
+#include <algorithm>
 
 namespace fastdeploy {
 
@@ -16,8 +16,8 @@
 
 #include <memory>
 
-#include "fastdeploy/backends/ort/ops/multiclass_nms.h"
 #include "fastdeploy/backends/ort/ops/adaptive_pool2d.h"
+#include "fastdeploy/backends/ort/ops/multiclass_nms.h"
 #include "fastdeploy/backends/ort/utils.h"
 #include "fastdeploy/core/float16.h"
 #include "fastdeploy/utils/utils.h"
@@ -64,7 +64,7 @@ void OrtBackend::BuildOption(const OrtBackendOption& option) {
 } else {
 OrtCUDAProviderOptions cuda_options;
 cuda_options.device_id = option.gpu_id;
-if(option.external_stream_) {
+if (option.external_stream_) {
 cuda_options.has_user_compute_stream = 1;
 cuda_options.user_compute_stream = option.external_stream_;
 }
@@ -91,11 +91,11 @@ bool OrtBackend::InitFromPaddle(const std::string& model_file,
 strcpy(ops[0].export_op_name, "MultiClassNMS");
 strcpy(ops[1].op_name, "pool2d");
 strcpy(ops[1].export_op_name, "AdaptivePool2d");
 
 if (!paddle2onnx::Export(model_file.c_str(), params_file.c_str(),
 &model_content_ptr, &model_content_size, 11, true,
-verbose, true, true, true, ops.data(),
-2, "onnxruntime", nullptr, 0, "", &save_external)) {
+verbose, true, true, true, ops.data(), 2,
+"onnxruntime", nullptr, 0, "", &save_external)) {
 FDERROR << "Error occured while export PaddlePaddle to ONNX format."
 << std::endl;
 return false;
@@ -105,11 +105,11 @@ bool OrtBackend::InitFromPaddle(const std::string& model_file,
 model_content_ptr + model_content_size);
 delete[] model_content_ptr;
 model_content_ptr = nullptr;
-if(save_external){
+if (save_external) {
 std::string model_file_name = "model.onnx";
 std::fstream f(model_file_name, std::ios::out);
 FDASSERT(f.is_open(), "Can not open file: %s to save model.",
 model_file_name.c_str());
 f << onnx_model_proto;
 f.close();
 return InitFromOnnx(model_file_name, option, false);
@@ -182,7 +182,7 @@ bool OrtBackend::InitFromOnnx(const std::string& model_file,
 }
 
 void OrtBackend::OrtValueToFDTensor(const Ort::Value& value, FDTensor* tensor,
 const std::string& name, bool copy_to_fd) {
 const auto info = value.GetTensorTypeAndShapeInfo();
 const auto data_type = info.GetElementType();
 size_t numel = info.GetElementCount();
@@ -216,15 +216,13 @@ void OrtBackend::OrtValueToFDTensor(const Ort::Value& value, FDTensor* tensor,
 memcpy(tensor->MutableData(), value_ptr, numel);
 } else {
 tensor->name = name;
-tensor->SetExternalData(
-shape, dtype,
-const_cast<void*>(value_ptr), Device::CPU);
+tensor->SetExternalData(shape, dtype, const_cast<void*>(value_ptr),
+Device::CPU);
 }
 }
 
 bool OrtBackend::Infer(std::vector<FDTensor>& inputs,
-std::vector<FDTensor>* outputs,
-bool copy_to_fd) {
+std::vector<FDTensor>* outputs, bool copy_to_fd) {
 if (inputs.size() != inputs_desc_.size()) {
 FDERROR << "[OrtBackend] Size of the inputs(" << inputs.size()
 << ") should keep same with the inputs of this model(
@@ -256,8 +254,8 @@ bool OrtBackend::Infer(std::vector<FDTensor>& inputs,
 std::vector<Ort::Value> ort_outputs = binding_->GetOutputValues();
 outputs->resize(ort_outputs.size());
 for (size_t i = 0; i < ort_outputs.size(); ++i) {
-OrtValueToFDTensor(ort_outputs[i], &((*outputs)[i]),
-outputs_desc_[i].name, copy_to_fd);
+OrtValueToFDTensor(ort_outputs[i], &((*outputs)[i]), outputs_desc_[i].name,
+copy_to_fd);
 }
 
 return true;
@@ -310,11 +308,13 @@ void OrtBackend::InitCustomOperators() {
 if (custom_operators_.size() == 0) {
 MultiClassNmsOp* multiclass_nms = new MultiClassNmsOp{};
 custom_operators_.push_back(multiclass_nms);
-if(option_.use_gpu){
-AdaptivePool2dOp* adaptive_pool2d = new AdaptivePool2dOp{"CUDAExecutionProvider"};
+if (option_.use_gpu) {
+AdaptivePool2dOp* adaptive_pool2d =
+new AdaptivePool2dOp{"CUDAExecutionProvider"};
 custom_operators_.push_back(adaptive_pool2d);
-}else{
-AdaptivePool2dOp* adaptive_pool2d = new AdaptivePool2dOp{"CPUExecutionProvider"};
+} else {
+AdaptivePool2dOp* adaptive_pool2d =
+new AdaptivePool2dOp{"CPUExecutionProvider"};
 custom_operators_.push_back(adaptive_pool2d);
 }
 }
@@ -18,6 +18,7 @@
 #include <memory>
 #include <string>
 #include <vector>
+#include <map>
 
 #include "fastdeploy/backends/backend.h"
 #include "onnxruntime_cxx_api.h" // NOLINT
@@ -67,8 +68,7 @@ class OrtBackend : public BaseBackend {
 const OrtBackendOption& option = OrtBackendOption(),
 bool from_memory_buffer = false);
 
-bool Infer(std::vector<FDTensor>& inputs,
-std::vector<FDTensor>* outputs,
+bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs,
 bool copy_to_fd = true) override;
 
 int NumInputs() const override { return inputs_desc_.size(); }
@@ -104,7 +104,8 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
 std::string contents;
 
 if (option.model_from_memory_) {
-config_.SetModelBuffer(model_file.c_str(), option.model_buffer_size_, params_file.c_str(), option.params_buffer_size_);
+config_.SetModelBuffer(model_file.c_str(), option.model_buffer_size_,
+params_file.c_str(), option.params_buffer_size_);
 contents = model_file;
 } else {
 config_.SetModel(model_file, params_file);
@@ -182,7 +183,9 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
 FDINFO << "Start generating shape range info file." << std::endl;
 paddle_infer::Config analysis_config;
 if (option.model_from_memory_) {
-analysis_config.SetModelBuffer(model_file.c_str(), option.model_buffer_size_, params_file.c_str(), option.params_buffer_size_);
+analysis_config.SetModelBuffer(
+model_file.c_str(), option.model_buffer_size_, params_file.c_str(),
+option.params_buffer_size_);
 } else {
 analysis_config.SetModel(model_file, params_file);
 }
@@ -30,24 +30,24 @@ void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor,
 auto place = ConvertFDDeviceToPlace(fd_tensor.device);
 if (fd_tensor.dtype == FDDataType::FP32) {
 if (place == paddle_infer::PlaceType::kGPU) {
 tensor->ShareExternalData(static_cast<const float*>(fd_tensor.Data()),
 shape, place);
 } else {
 tensor->CopyFromCpu(static_cast<const float*>(fd_tensor.Data()));
 }
 return;
 } else if (fd_tensor.dtype == FDDataType::INT32) {
 if (place == paddle_infer::PlaceType::kGPU) {
 tensor->ShareExternalData(static_cast<const int32_t*>(fd_tensor.Data()),
 shape, place);
 } else {
 tensor->CopyFromCpu(static_cast<const int32_t*>(fd_tensor.Data()));
 }
 return;
 } else if (fd_tensor.dtype == FDDataType::INT64) {
 if (place == paddle_infer::PlaceType::kGPU) {
 tensor->ShareExternalData(static_cast<const int64_t*>(fd_tensor.Data()),
 shape, place);
 } else {
 tensor->CopyFromCpu(static_cast<const int64_t*>(fd_tensor.Data()));
 }
@@ -62,13 +62,12 @@ void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor,
 }
 
 void PaddleTensorToFDTensor(std::unique_ptr<paddle_infer::Tensor>& tensor,
-FDTensor* fd_tensor,
-bool copy_to_fd) {
+FDTensor* fd_tensor, bool copy_to_fd) {
 auto fd_dtype = PaddleDataTypeToFD(tensor->type());
 std::vector<int64_t> shape;
 auto tmp_shape = tensor->shape();
 shape.assign(tmp_shape.begin(), tmp_shape.end());
-if(copy_to_fd) {
+if (copy_to_fd) {
 fd_tensor->Resize(shape, fd_dtype, tensor->name());
 if (fd_tensor->dtype == FDDataType::FP32) {
 tensor->CopyToCpu(static_cast<float*>(fd_tensor->MutableData()));
@@ -79,9 +78,9 @@ void PaddleTensorToFDTensor(std::unique_ptr<paddle_infer::Tensor>& tensor,
 } else if (fd_tensor->dtype == FDDataType::INT64) {
 tensor->CopyToCpu(static_cast<int64_t*>(fd_tensor->MutableData()));
 return;
 }
 FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
 Str(fd_tensor->dtype).c_str());
 } else {
 paddle_infer::PlaceType place;
 int size = 0;
@@ -99,17 +98,17 @@ void PaddleTensorToFDTensor(std::unique_ptr<paddle_infer::Tensor>& tensor,
 } else if (fd_dtype == FDDataType::UINT8) {
 out_data = tensor->data<uint8_t>(&place, &size);
 } else {
-FDASSERT(false, "Unexpected data type(%s) while infer shared with PaddleBackend.",
+FDASSERT(
+false,
+"Unexpected data type(%s) while infer shared with PaddleBackend.",
 Str(fd_dtype).c_str());
 }
 Device device = Device::CPU;
-if(place == paddle_infer::PlaceType::kGPU) {
+if (place == paddle_infer::PlaceType::kGPU) {
 device = Device::GPU;
 }
 fd_tensor->name = tensor->name();
-fd_tensor->SetExternalData(
-shape, fd_dtype,
-out_data, device);
+fd_tensor->SetExternalData(shape, fd_dtype, out_data, device);
 }
 }
 
@@ -153,7 +152,10 @@ FDDataType ReaderDataTypeToFD(int32_t dtype) {
 } else if (dtype == 6) {
 fd_dtype = FDDataType::FP16;
 } else {
-FDASSERT(false, "Unexpected data type: %d while call ReaderDataTypeToFD in PaddleBackend.", dtype);
+FDASSERT(false,
+"Unexpected data type: %d while call ReaderDataTypeToFD in "
+"PaddleBackend.",
+dtype);
 }
 return fd_dtype;
 }
@@ -14,14 +14,14 @@
 
 #pragma once
 
-#include <string>
 #include <algorithm>
-#include <unordered_map>
 #include <set>
+#include <string>
+#include <unordered_map>
 
-#include "torch/script.h"
 #include "iengine.h"
 #include "poros_module.h"
+#include "torch/script.h"
 
 namespace baidu {
 namespace mirana {
@@ -36,28 +36,29 @@ namespace poros {
 * @return porosmodule
 * @retval !nullptr => succeed nullptr => failed
 **/
-std::unique_ptr<PorosModule> Compile(const torch::jit::Module& module,
-const std::vector<std::vector<c10::IValue> >& prewarm_datas,
+std::unique_ptr<PorosModule>
+Compile(const torch::jit::Module& module,
+const std::vector<std::vector<c10::IValue>>& prewarm_datas,
 const PorosOptions& options);
 
 class Compiler {
 public:
 typedef std::unordered_map<const torch::jit::Node*, IEngine*> engine_map_t;
-typedef std::vector<std::vector<c10::IValue> > ivalue_vec_t;
+typedef std::vector<std::vector<c10::IValue>> ivalue_vec_t;
 
 Compiler() : _origin_module(NULL) {}
 ~Compiler();
 
 /**
 * @brief initial Compiler
 *
 * @param [in] options : poros options
 * @return int
 * @retval 0 => succeed <0 => failed
 **/
 int init(const PorosOptions& options);
 
 /**
 * @brief compile whole graph
 *
 * @param [in] origin_module
@@ -66,13 +67,12 @@ public:
 * @return int
 * @retval 0 => succeed <0 => failed
 **/
 int compile(const torch::jit::Module& origin_module,
 const ivalue_vec_t& prewarm_datas,
 torch::jit::Module* optimized_module);
 
 private:
-
 /**
 * @brief preprocess this calculation graph
 *
 * @param [in] prewarm_datas : ivalue_vec_t, vector of IValue
@@ -80,23 +80,25 @@ private:
 * @return int
 * @retval 0 => succeed <0 => failed
 **/
-int preprocess_graph(const ivalue_vec_t& prewarm_datas, std::shared_ptr<torch::jit::Graph>& graph);
+int preprocess_graph(const ivalue_vec_t& prewarm_datas,
+std::shared_ptr<torch::jit::Graph>& graph);
 
 /**
 * @brief segement this calculation graph
 *
 * @param [in/out] graph
 * @return int
 * @retval 0 => succeed <0 => failed
 **/
 int segment_graph(std::shared_ptr<torch::jit::Graph>& graph);
 
 // Split subgraph(block)
 // The divided subgraph, as a subgraph, is associated with the block
-int segment_block(torch::jit::Block& block, IEngine* engine, int current_depth);
+int segment_block(torch::jit::Block& block, IEngine* engine,
+int current_depth);
 
 // Subgraph optimization
 /**
 * @brief Subgraph optimization
 *
 * @param [in] prewarm_datas : ivalue_vec_t, vector of IValue
@@ -105,15 +107,15 @@ private:
 * @return int
 * @retval 0 => succeed <0 => failed
 **/
 int optimize_subgraph(const ivalue_vec_t& prewarm_datas,
 const std::shared_ptr<torch::jit::Graph>& opt_graph,
 torch::jit::Module* optimized_module);
 
 // Subgraph optimization(block)
 int optimize_subblock(torch::jit::Block* block,
 torch::jit::Module* optimized_module);
 
 /**
 * @brief Compile the subgraph into a new graph based on the engine
 *
 * @param [in] engine : The engine used by the subgraph
@@ -121,32 +123,32 @@ private:
 * @return [out] module : Transformed model
 * @retval 0 => succeed <0 => failed
 **/
 int transform(IEngine* engine, torch::jit::Node& subgraph_node,
 torch::jit::Module& module);
 
 /**
 * @brief Select engine based on subgraph and options
 *
 * @param [in] node : Jit Node
 * @return int
 * @retval 0 => succeed <0 => failed
 **/
 IEngine* select_engine(const torch::jit::Node* n);
 
 /**
 * @brief destory
 *
 * @return void
 **/
 void close();
 
 private:
 int _max_segment_depth{5}; // Maximum subgraph segmentation depth
 ivalue_vec_t _prewarm_datas; // Prewarm datas
 PorosOptions _options;
 engine_map_t _engine_map; // The engine used to record the subgraph
 const torch::jit::Module* _origin_module; // Origin_module
 std::atomic<int> _engine_index = {0}; // Record engine index
 };
 
 /**
@@ -158,9 +160,10 @@ private:
 * @return optimized_module
 * @retval !nullptr => succeed nullptr => failed
 **/
-std::unique_ptr<torch::jit::Module> CompileGraph(const torch::jit::Module& module,
-const std::vector<std::vector<c10::IValue> >& prewarm_datas,
-const PorosOptions& options);
+std::unique_ptr<torch::jit::Module>
+CompileGraph(const torch::jit::Module& module,
+const std::vector<std::vector<c10::IValue>>& prewarm_datas,
+const PorosOptions& options);
 
 } // namespace poros
 } // namespace mirana
@@ -17,9 +17,9 @@
 #include <string>
 
 //from pytorch
-#include "torch/script.h"
-#include "torch/csrc/jit/ir/ir.h"
 #include "ATen/core/interned_strings.h"
+#include "torch/csrc/jit/ir/ir.h"
+#include "torch/script.h"
 
 #include "plugin_create.h"
 
@@ -28,50 +28,51 @@ namespace mirana {
 namespace poros {
 
 struct PorosGraph {
 torch::jit::Graph* graph = NULL;
 torch::jit::Node* node = NULL;
 };
 
 typedef uint64_t EngineID;
 
-class IEngine : public IPlugin, public torch::CustomClassHolder{
+class IEngine : public IPlugin, public torch::CustomClassHolder {
 public:
 virtual ~IEngine() {}
 
 /**
 * @brief init, initialization must be successful if the init is successful
 * @return int
 * @retval 0 => success, <0 => fail
 **/
 virtual int init() = 0;
 
 /**
 * @brief During compilation, the subgraph is converted into the graph structure of the corresponding engine and stored inside the engine, so that the execute_engine at runtime can be called
 * @param [in] sub_graph : subgraph
 * @return [res]int
 * @retval 0 => success, <0 => fail
 **/
 virtual int transform(const PorosGraph& sub_graph) = 0;
 
 /**
 * @brief Subgraph execution period logic
 * @param [in] inputs : input tensor
 * @return [res] output tensor
 **/
-virtual std::vector<at::Tensor> excute_engine(const std::vector<at::Tensor>& inputs) = 0;
+virtual std::vector<at::Tensor>
+excute_engine(const std::vector<at::Tensor>& inputs) = 0;
 
-virtual void register_module_attribute(const std::string& name, torch::jit::Module& module) = 0;
+virtual void register_module_attribute(const std::string& name,
+torch::jit::Module& module) = 0;
 
 // Logo
 virtual const std::string who_am_i() = 0;
 
 // Whether the node is supported by the current engine
 bool is_node_supported(const torch::jit::Node* node);
 
-public:
-std::pair<uint64_t, uint64_t> _num_io; // Number of input/output parameters
-EngineID _id;
 
+public:
+std::pair<uint64_t, uint64_t> _num_io; // Number of input/output parameters
+EngineID _id;
 };
 
 } // namespace poros
@@ -14,52 +14,56 @@
 
 #pragma once
 
-#include <unordered_map>
 #include <string>
+#include <unordered_map>
 
 namespace baidu {
 namespace mirana {
 namespace poros {
 
 class IPlugin {
 public:
 virtual ~IPlugin() {}
 virtual const std::string who_am_i() = 0;
 };
 
 typedef IPlugin* (*plugin_creator_t)();
 typedef std::unordered_map<std::string, plugin_creator_t> plugin_creator_map_t;
 
 IPlugin* create_plugin(const std::string& plugin_name);
-IPlugin* create_plugin(const std::string& plugin_name, const plugin_creator_map_t& plugin_creator_map);
+IPlugin* create_plugin(const std::string& plugin_name,
+const plugin_creator_map_t& plugin_creator_map);
 
 void create_all_plugins(const plugin_creator_map_t& plugin_creator_map,
 std::unordered_map<std::string, IPlugin*>& plugin_m);
 //void create_all_plugins(std::unordered_map<std::string, IPlugin*>& plugin_m);
 
-template <typename PluginType>
-IPlugin* default_plugin_creator() {
-return new (std::nothrow)PluginType;
+template <typename PluginType> IPlugin* default_plugin_creator() {
+return new (std::nothrow) PluginType;
 }
 
-void register_plugin_creator(const std::string& plugin_name, plugin_creator_t creator);
 void register_plugin_creator(const std::string& plugin_name,
-plugin_creator_t creator, plugin_creator_map_t& plugin_creator_map);
+plugin_creator_t creator);
+void register_plugin_creator(const std::string& plugin_name,
+plugin_creator_t creator,
+plugin_creator_map_t& plugin_creator_map);
 
 template <typename PluginType>
 void register_plugin_class(const std::string& plugin_name) {
-return register_plugin_creator(plugin_name, default_plugin_creator<PluginType>);
+return register_plugin_creator(plugin_name,
+default_plugin_creator<PluginType>);
 }
 
 // This version is recommended
 template <typename PluginType>
-void register_plugin_class(const std::string& plugin_name, plugin_creator_map_t& plugin_creator_map) {
-return register_plugin_creator(plugin_name, default_plugin_creator<PluginType>, plugin_creator_map);
+void register_plugin_class(const std::string& plugin_name,
+plugin_creator_map_t& plugin_creator_map) {
+return register_plugin_creator(
+plugin_name, default_plugin_creator<PluginType>, plugin_creator_map);
 }
 
-}//poros
-}//mirana
-}//baidu
+} // namespace poros
+} // namespace mirana
+} // namespace baidu
 
 
 /* vim: set ts=4 sw=4 sts=4 tw=100 */
@@ -14,53 +14,45 @@
 
 #pragma once
 
-#include <string>
-#include "torch/script.h"
 #include "torch/csrc/jit/jit_log.h"
+#include "torch/script.h"
+#include <string>
 // #include "ATen/Context.h"
 
 namespace baidu {
 namespace mirana {
 namespace poros {
 
-enum Device : int8_t {
-GPU = 0,
-CPU,
-XPU,
-UNKNOW
-};
+enum Device : int8_t { GPU = 0, CPU, XPU, UNKNOW };
 
 struct PorosOptions {
 Device device = GPU;
 bool debug = false;
 bool use_fp16 = false;
 bool is_dynamic = false;
 bool long_to_int = true;
 uint64_t max_workspace_size = 1ULL << 30;
 int32_t device_id = -1;
 int32_t unconst_ops_thres = -1;
 bool use_nvidia_tf32 = false;
 };
 
 class PorosModule : public torch::jit::Module {
 public:
-PorosModule(torch::jit::Module module) : torch::jit::Module(module) {
-}
+PorosModule(torch::jit::Module module) : torch::jit::Module(module) {}
 ~PorosModule() = default;
 
-void to_device(Device device){
-_options.device = device;
-}
+void to_device(Device device) { _options.device = device; }
 
-//c10::IValue forward(std::vector<c10::IValue> inputs);
-//void save(const std::string& filename);
-public:
-PorosOptions _options;
 
+//c10::IValue forward(std::vector<c10::IValue> inputs);
+//void save(const std::string& filename);
+public:
+PorosOptions _options;
 };
 
 //via porosmodule.save
-std::unique_ptr<PorosModule> Load(const std::string& filename, const PorosOptions& options);
+std::unique_ptr<PorosModule> Load(const std::string& filename,
+const PorosOptions& options);
 
 } // namespace poros
 } // namespace mirana
@@ -188,8 +188,7 @@ bool PorosBackend::InitFromPoros(const std::string& model_file,
 }
 
 bool PorosBackend::Infer(std::vector<FDTensor>& inputs,
-std::vector<FDTensor>* outputs,
-bool copy_to_fd) {
+std::vector<FDTensor>* outputs, bool copy_to_fd) {
 // Convert FD Tensor to PyTorch Tensor
 std::vector<torch::jit::IValue> poros_inputs;
 bool is_backend_cuda =
@@ -74,9 +74,9 @@ class PorosBackend : public BaseBackend {
 
 void BuildOption(const PorosBackendOption& option);
 
-bool InitFromTorchScript(
-const std::string& model_file,
+bool
+InitFromTorchScript(const std::string& model_file,
 const PorosBackendOption& option = PorosBackendOption());
 
 bool InitFromPoros(const std::string& model_file,
 const PorosBackendOption& option = PorosBackendOption());
@@ -85,8 +85,7 @@ class PorosBackend : public BaseBackend {
 std::vector<std::vector<FDTensor>>& prewarm_tensors,
 const PorosBackendOption& option = PorosBackendOption());
 
-bool Infer(std::vector<FDTensor>& inputs,
-std::vector<FDTensor>* outputs,
+bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs,
 bool copy_to_fd = true) override;
 
 int NumInputs() const { return _numinputs; }
@@ -23,32 +23,32 @@ namespace fastdeploy {
 std::string AtType2String(const at::ScalarType& dtype) {
 std::string out;
 switch (dtype) {
 case at::kByte:
 out = "at::kByte";
 break;
 case at::kChar:
 out = "at::kChar";
 break;
 case at::kShort:
 out = "at::kShort";
 break;
 case at::kInt:
 out = "at::kInt";
 break;
 case at::kLong:
 out = "at::kLong";
 break;
 case at::kHalf:
 out = "at::kHalf";
 break;
 case at::kFloat:
 out = "at::kFloat";
 break;
 case at::kDouble:
 out = "at::kDouble";
 break;
 default:
 out = "at::UNKNOWN";
 }
 return out;
 }
@@ -129,9 +129,8 @@ at::Tensor CreatePorosValue(FDTensor& tensor, bool is_backend_cuda) {
 numel * sizeof(double));
 }
 } else {
-FDASSERT(false,
-"Unrecognized data type while calling "
-"PorosBackend::CreatePorosValue().");
+FDASSERT(false, "Unrecognized data type while calling "
+"PorosBackend::CreatePorosValue().");
 }
 return poros_value;
 }
@@ -27,14 +27,14 @@ RKNPU2Backend::~RKNPU2Backend() {
 for (uint32_t i = 0; i < io_num.n_input; i++) {
 rknn_destroy_mem(ctx, input_mems_[i]);
 }
-if(input_mems_ != nullptr){
+if (input_mems_ != nullptr) {
 free(input_mems_);
 }
 
 for (uint32_t i = 0; i < io_num.n_output; i++) {
 rknn_destroy_mem(ctx, output_mems_[i]);
 }
-if(output_mems_ != nullptr){
+if (output_mems_ != nullptr) {
 free(output_mems_);
 }
 }
@@ -173,16 +173,15 @@ bool RKNPU2Backend::GetModelInputOutputInfos() {
 
 // create input tensor memory
 // rknn_tensor_mem* input_mems[io_num.n_input];
-input_mems_ = (rknn_tensor_mem**)malloc(sizeof(rknn_tensor_mem*) * io_num.n_input);
+input_mems_ =
+(rknn_tensor_mem**)malloc(sizeof(rknn_tensor_mem*) * io_num.n_input);
 
 // get input info and copy to input tensor info
 for (uint32_t i = 0; i < io_num.n_input; i++) {
 input_attrs_[i].index = i;
 
 // query info
-ret = rknn_query(ctx,
-RKNN_QUERY_INPUT_ATTR,
-&(input_attrs_[i]),
+ret = rknn_query(ctx, RKNN_QUERY_INPUT_ATTR, &(input_attrs_[i]),
 sizeof(rknn_tensor_attr));
 DumpTensorAttr(input_attrs_[i]);
 
@@ -190,12 +189,12 @@ bool RKNPU2Backend::GetModelInputOutputInfos() {
|
|||||||
printf("rknn_init error! ret=%d\n", ret);
|
printf("rknn_init error! ret=%d\n", ret);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if((input_attrs_[i].fmt != RKNN_TENSOR_NHWC) &&
|
if ((input_attrs_[i].fmt != RKNN_TENSOR_NHWC) &&
|
||||||
(input_attrs_[i].fmt != RKNN_TENSOR_UNDEFINED)){
|
(input_attrs_[i].fmt != RKNN_TENSOR_UNDEFINED)) {
|
||||||
FDERROR << "rknpu2_backend only support input format is NHWC or UNDEFINED" << std::endl;
|
FDERROR << "rknpu2_backend only support input format is NHWC or UNDEFINED"
|
||||||
|
<< std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// copy input_attrs_ to input tensor info
|
// copy input_attrs_ to input tensor info
|
||||||
std::string temp_name = input_attrs_[i].name;
|
std::string temp_name = input_attrs_[i].name;
|
||||||
std::vector<int> temp_shape{};
|
std::vector<int> temp_shape{};
|
||||||
@@ -203,25 +202,28 @@ bool RKNPU2Backend::GetModelInputOutputInfos() {
|
|||||||
for (int j = 0; j < input_attrs_[i].n_dims; j++) {
|
for (int j = 0; j < input_attrs_[i].n_dims; j++) {
|
||||||
temp_shape[j] = (int)input_attrs_[i].dims[j];
|
temp_shape[j] = (int)input_attrs_[i].dims[j];
|
||||||
}
|
}
|
||||||
FDDataType temp_dtype = fastdeploy::RKNPU2Backend::RknnTensorTypeToFDDataType(input_attrs_[i].type);
|
FDDataType temp_dtype =
|
||||||
|
fastdeploy::RKNPU2Backend::RknnTensorTypeToFDDataType(
|
||||||
|
input_attrs_[i].type);
|
||||||
TensorInfo temp_input_info = {temp_name, temp_shape, temp_dtype};
|
TensorInfo temp_input_info = {temp_name, temp_shape, temp_dtype};
|
||||||
inputs_desc_[i] = temp_input_info;
|
inputs_desc_[i] = temp_input_info;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get detailed output parameters
|
// Get detailed output parameters
|
||||||
output_attrs_ = (rknn_tensor_attr*)malloc(sizeof(rknn_tensor_attr) * io_num.n_output);
|
output_attrs_ =
|
||||||
|
(rknn_tensor_attr*)malloc(sizeof(rknn_tensor_attr) * io_num.n_output);
|
||||||
memset(output_attrs_, 0, io_num.n_output * sizeof(rknn_tensor_attr));
|
memset(output_attrs_, 0, io_num.n_output * sizeof(rknn_tensor_attr));
|
||||||
outputs_desc_.resize(io_num.n_output);
|
outputs_desc_.resize(io_num.n_output);
|
||||||
|
|
||||||
// Create output tensor memory
|
// Create output tensor memory
|
||||||
output_mems_ = (rknn_tensor_mem**)malloc(sizeof(rknn_tensor_mem*) * io_num.n_output);;
|
output_mems_ =
|
||||||
|
(rknn_tensor_mem**)malloc(sizeof(rknn_tensor_mem*) * io_num.n_output);
|
||||||
|
;
|
||||||
|
|
||||||
for (uint32_t i = 0; i < io_num.n_output; i++) {
|
for (uint32_t i = 0; i < io_num.n_output; i++) {
|
||||||
output_attrs_[i].index = i;
|
output_attrs_[i].index = i;
|
||||||
// query info
|
// query info
|
||||||
ret = rknn_query(ctx,
|
ret = rknn_query(ctx, RKNN_QUERY_OUTPUT_ATTR, &(output_attrs_[i]),
|
||||||
RKNN_QUERY_OUTPUT_ATTR,
|
|
||||||
&(output_attrs_[i]),
|
|
||||||
sizeof(rknn_tensor_attr));
|
sizeof(rknn_tensor_attr));
|
||||||
DumpTensorAttr(output_attrs_[i]);
|
DumpTensorAttr(output_attrs_[i]);
|
||||||
|
|
||||||
@@ -233,7 +235,7 @@ bool RKNPU2Backend::GetModelInputOutputInfos() {
|
|||||||
// If the output dimension is 3, the runtime will automatically change it to 4.
|
// If the output dimension is 3, the runtime will automatically change it to 4.
|
||||||
// Obviously, this is wrong, and manual correction is required here.
|
// Obviously, this is wrong, and manual correction is required here.
|
||||||
int n_dims = output_attrs_[i].n_dims;
|
int n_dims = output_attrs_[i].n_dims;
|
||||||
if((n_dims == 4) && (output_attrs_[i].dims[3] == 1)){
|
if ((n_dims == 4) && (output_attrs_[i].dims[3] == 1)) {
|
||||||
n_dims--;
|
n_dims--;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -292,8 +294,7 @@ std::vector<TensorInfo> RKNPU2Backend::GetOutputInfos() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool RKNPU2Backend::Infer(std::vector<FDTensor>& inputs,
|
bool RKNPU2Backend::Infer(std::vector<FDTensor>& inputs,
|
||||||
std::vector<FDTensor>* outputs,
|
std::vector<FDTensor>* outputs, bool copy_to_fd) {
|
||||||
bool copy_to_fd) {
|
|
||||||
int ret = RKNN_SUCC;
|
int ret = RKNN_SUCC;
|
||||||
// Judge whether the input and output size are the same
|
// Judge whether the input and output size are the same
|
||||||
if (inputs.size() != inputs_desc_.size()) {
|
if (inputs.size() != inputs_desc_.size()) {
|
||||||
@@ -303,15 +304,17 @@ bool RKNPU2Backend::Infer(std::vector<FDTensor>& inputs,
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if(!this->infer_init){
|
if (!this->infer_init) {
|
||||||
for (uint32_t i = 0; i < io_num.n_input; i++) {
|
for (uint32_t i = 0; i < io_num.n_input; i++) {
|
||||||
// Judge whether the input and output types are the same
|
// Judge whether the input and output types are the same
|
||||||
rknn_tensor_type input_type =
|
rknn_tensor_type input_type =
|
||||||
fastdeploy::RKNPU2Backend::FDDataTypeToRknnTensorType(inputs[i].dtype);
|
fastdeploy::RKNPU2Backend::FDDataTypeToRknnTensorType(
|
||||||
|
inputs[i].dtype);
|
||||||
if (input_type != input_attrs_[i].type) {
|
if (input_type != input_attrs_[i].type) {
|
||||||
FDWARNING << "The input tensor type != model's inputs type."
|
FDWARNING << "The input tensor type != model's inputs type."
|
||||||
<< "The input_type need " << get_type_string(input_attrs_[i].type)
|
<< "The input_type need "
|
||||||
<< ",but inputs["<< i << "].type is " << get_type_string(input_type)
|
<< get_type_string(input_attrs_[i].type) << ",but inputs["
|
||||||
|
<< i << "].type is " << get_type_string(input_type)
|
||||||
<< std::endl;
|
<< std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -319,10 +322,11 @@ bool RKNPU2Backend::Infer(std::vector<FDTensor>& inputs,
|
|||||||
input_attrs_[i].type = input_type;
|
input_attrs_[i].type = input_type;
|
||||||
input_attrs_[i].size = inputs[0].Nbytes();
|
input_attrs_[i].size = inputs[0].Nbytes();
|
||||||
input_attrs_[i].size_with_stride = inputs[0].Nbytes();
|
input_attrs_[i].size_with_stride = inputs[0].Nbytes();
|
||||||
if(input_attrs_[i].type == RKNN_TENSOR_FLOAT16 ||
|
if (input_attrs_[i].type == RKNN_TENSOR_FLOAT16 ||
|
||||||
input_attrs_[i].type == RKNN_TENSOR_FLOAT32){
|
input_attrs_[i].type == RKNN_TENSOR_FLOAT32) {
|
||||||
FDINFO << "The input model is not a quantitative model. "
|
FDINFO << "The input model is not a quantitative model. "
|
||||||
"Close the normalize operation." << std::endl;
|
"Close the normalize operation."
|
||||||
|
<< std::endl;
|
||||||
}
|
}
|
||||||
|
|
||||||
input_mems_[i] = rknn_create_mem(ctx, inputs[i].Nbytes());
|
input_mems_[i] = rknn_create_mem(ctx, inputs[i].Nbytes());
|
||||||
@@ -474,4 +478,4 @@ RKNPU2Backend::FDDataTypeToRknnTensorType(fastdeploy::FDDataType type) {
|
|||||||
FDERROR << "rknn_tensor_type don't support this type" << std::endl;
|
FDERROR << "rknn_tensor_type don't support this type" << std::endl;
|
||||||
return RKNN_TENSOR_TYPE_MAX;
|
return RKNN_TENSOR_TYPE_MAX;
|
||||||
}
|
}
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
@@ -14,9 +14,9 @@
#pragma once

#include "fastdeploy/backends/backend.h"
#include "fastdeploy/backends/rknpu/rknpu2/rknpu2_config.h"
#include "fastdeploy/core/fd_tensor.h"
#include "rknn_api.h" // NOLINT
#include <cstring>
#include <iostream>
#include <memory>

@@ -71,8 +71,7 @@ class RKNPU2Backend : public BaseBackend {
  TensorInfo GetOutputInfo(int index) override;
  std::vector<TensorInfo> GetInputInfos() override;
  std::vector<TensorInfo> GetOutputInfos() override;
  bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs,
             bool copy_to_fd = true) override;

 private:
@@ -24,9 +24,9 @@ typedef enum _rknpu2_cpu_name {
/*! RKNPU2 core mask for mobile device. */
typedef enum _rknpu2_core_mask {
  RKNN_NPU_CORE_AUTO = 0,  //< default, run on NPU core randomly.
  RKNN_NPU_CORE_0 = 1,     //< run on NPU core 0.
  RKNN_NPU_CORE_1 = 2,     //< run on NPU core 1.
  RKNN_NPU_CORE_2 = 4,     //< run on NPU core 2.
  RKNN_NPU_CORE_0_1 =
      RKNN_NPU_CORE_0 | RKNN_NPU_CORE_1,  //< run on NPU core 1 and core 2.
  RKNN_NPU_CORE_0_1_2 =
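The core-mask values above are plain bit flags, so the multi-core entries are just the bitwise OR of the single-core masks. A minimal sketch of that relationship, assuming only the enum constants from this header are in scope (chosen_mask is an illustrative variable, not part of the header):

    // RKNN_NPU_CORE_0_1 has the same value as OR-ing the two single-core flags.
    int chosen_mask = RKNN_NPU_CORE_0 | RKNN_NPU_CORE_1;
    // Test whether a particular core bit is selected.
    bool uses_core_1 = (chosen_mask & RKNN_NPU_CORE_1) != 0;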
@@ -17,108 +17,106 @@
namespace fastdeploy {

nvinfer1::PluginFieldCollection AdaptivePool2dPluginCreator::mFC{};
std::vector<nvinfer1::PluginField>
    AdaptivePool2dPluginCreator::mPluginAttributes;

pluginStatus_t AdaptivePool2dInference(cudaStream_t stream, int32_t n,
                                       const void* input, void* output);

AdaptivePool2d::AdaptivePool2d(std::vector<int32_t> output_size,
                               std::string pooling_type) {
  output_size_ = output_size;
  pooling_type_ = pooling_type;
}

AdaptivePool2d::AdaptivePool2d(const void* buffer, size_t length) {
  const char *d = reinterpret_cast<const char*>(buffer), *a = d;
  output_size_.resize(4);
  for (int64_t i = 0; i < 4; i++) {
    output_size_[i] = read<int32_t>(d);
  }
  if (read<int32_t>(d) == 0) {
    pooling_type_ = "avg";
  } else {
    pooling_type_ = "max";
  }
  FDASSERT(d == a + length, "deserialize failed.");
}

int AdaptivePool2d::getNbOutputs() const noexcept { return 1; }

nvinfer1::DimsExprs AdaptivePool2d::getOutputDimensions(
    int outputIndex, const nvinfer1::DimsExprs* inputs, int nbInputs,
    nvinfer1::IExprBuilder& exprBuilder) noexcept {
  try {
    nvinfer1::DimsExprs output(inputs[0]);
    output.d[2] = exprBuilder.constant(static_cast<int32_t>(output_size_[2]));
    output.d[3] = exprBuilder.constant(static_cast<int32_t>(output_size_[3]));
    return output;
  } catch (const std::exception& e) {
    FDASSERT(false, "getOutputDimensions failed: %s.", e.what());
  }
  return nvinfer1::DimsExprs{};
}

int AdaptivePool2d::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
                            const nvinfer1::PluginTensorDesc* outputDesc,
                            const void* const* inputs, void* const* outputs,
                            void* workspace, cudaStream_t stream) noexcept {
  if (inputDesc[0].type != nvinfer1::DataType::kFLOAT) {
    return -1;
  }
  auto const* data = static_cast<float const*>(inputs[0]);
  auto* result = static_cast<float*>(outputs[0]);
  int nums = outputDesc[0].dims.d[0] * outputDesc[0].dims.d[1] *
             outputDesc[0].dims.d[2] * outputDesc[0].dims.d[3];
  std::vector<int64_t> input_size, output_size;
  for (int i = 0; i < 4; i++) {
    input_size.push_back(inputDesc[0].dims.d[i]);
    output_size.push_back(outputDesc[0].dims.d[i]);
  }
  CudaAdaptivePool(input_size, output_size, result, data, stream,
                   pooling_type_);
  return cudaPeekAtLastError();
}

size_t AdaptivePool2d::getSerializationSize() const noexcept {
  return 5 * sizeof(int32_t);
}

void AdaptivePool2d::serialize(void* buffer) const noexcept {
  char *d = reinterpret_cast<char*>(buffer), *a = d;
  for (int64_t i = 0; i < 4; i++) {
    write(d, output_size_[i]);
  }
  int32_t pooling_type_val = 0;
  if (pooling_type_ != "avg") {
    pooling_type_val = 1;
  }
  write(d, pooling_type_val);
  FDASSERT(d == a + getSerializationSize(), "d == a + getSerializationSize()");
}

nvinfer1::DataType
AdaptivePool2d::getOutputDataType(int index,
                                  const nvinfer1::DataType* inputType,
                                  int nbInputs) const noexcept {
  return inputType[0];
}

bool AdaptivePool2d::supportsFormatCombination(
    int pos, const nvinfer1::PluginTensorDesc* inOut, int nbInputs,
    int nbOutputs) noexcept {
  return (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);
}

int AdaptivePool2d::initialize() noexcept { return 0; }

void AdaptivePool2d::terminate() noexcept { return; }

size_t AdaptivePool2d::getWorkspaceSize(
    const nvinfer1::PluginTensorDesc* inputs, int nbInputs,
    const nvinfer1::PluginTensorDesc* outputs, int nbOutputs) const noexcept {
  return 0;
}
@@ -126,33 +124,32 @@ const char* AdaptivePool2d::getPluginType() const noexcept {
  return "AdaptivePool2d";
}

const char* AdaptivePool2d::getPluginVersion() const noexcept { return "1"; }

void AdaptivePool2d::destroy() noexcept { return; }

void AdaptivePool2d::configurePlugin(
    const nvinfer1::DynamicPluginTensorDesc* in, int nbInputs,
    const nvinfer1::DynamicPluginTensorDesc* out, int nbOutputs) noexcept {
  return;
}

nvinfer1::IPluginV2DynamicExt* AdaptivePool2d::clone() const noexcept {
  try {
    nvinfer1::IPluginV2DynamicExt* plugin =
        new AdaptivePool2d(output_size_, pooling_type_);
    plugin->setPluginNamespace(mNamespace.c_str());
    return plugin;
  } catch (std::exception const& e) {
    FDASSERT(false, "clone failed: %s.", e.what());
  }
  return nullptr;
}

AdaptivePool2dPluginCreator::AdaptivePool2dPluginCreator() {
  mPluginAttributes.clear();
  mPluginAttributes.emplace_back(nvinfer1::PluginField(
      "output_size", nullptr, nvinfer1::PluginFieldType::kINT32, 4));
  mPluginAttributes.emplace_back(nvinfer1::PluginField(
      "pooling_type", nullptr, nvinfer1::PluginFieldType::kCHAR, 3));

  mFC.nbFields = mPluginAttributes.size();
  mFC.fields = mPluginAttributes.data();

@@ -166,17 +163,18 @@ const char* AdaptivePool2dPluginCreator::getPluginVersion() const noexcept {
  return "1";
}

const nvinfer1::PluginFieldCollection*
AdaptivePool2dPluginCreator::getFieldNames() noexcept {
  return &mFC;
}

nvinfer1::IPluginV2DynamicExt* AdaptivePool2dPluginCreator::createPlugin(
    const char* name, const nvinfer1::PluginFieldCollection* fc) noexcept {
  try {
    const nvinfer1::PluginField* fields = fc->fields;
    auto const dims = static_cast<int32_t const*>(fields[0].data);
    output_size_.resize(4);
    for (int64_t i = 0; i < 4; i++) {
      output_size_[i] = dims[i];
    }

@@ -184,23 +182,20 @@ nvinfer1::IPluginV2DynamicExt* AdaptivePool2dPluginCreator::createPlugin(const c
    std::string pooling_type(pooling_type_ptr, 3);
    pooling_type_ = pooling_type;
    return new AdaptivePool2d(output_size_, pooling_type_);
  } catch (std::exception const& e) {
    FDASSERT(false, "createPlugin failed: %s.", e.what());
  }
  return nullptr;
}

nvinfer1::IPluginV2DynamicExt* AdaptivePool2dPluginCreator::deserializePlugin(
    const char* name, const void* serialData, size_t serialLength) noexcept {
  try {
    return new AdaptivePool2d(serialData, serialLength);
  } catch (std::exception const& e) {
    FDASSERT(false, "deserializePlugin failed: %s.", e.what());
  }
  return nullptr;
}

}  // namespace fastdeploy
@@ -13,98 +13,93 @@
// limitations under the License.

#pragma once
#include "common.h" // NOLINT
#include "fastdeploy/backends/op_cuda_kernels/adaptive_pool2d_kernel.h"

namespace fastdeploy {

class AdaptivePool2d : public BasePlugin {
 public:
  AdaptivePool2d(std::vector<int32_t> output_size, std::string pooling_type);

  AdaptivePool2d(const void* buffer, size_t length);

  ~AdaptivePool2d() override = default;

  int getNbOutputs() const noexcept override;

  nvinfer1::DimsExprs
  getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs,
                      int nbInputs,
                      nvinfer1::IExprBuilder& exprBuilder) noexcept override;

  nvinfer1::DataType getOutputDataType(int index,
                                       const nvinfer1::DataType* inputType,
                                       int nbInputs) const noexcept override;

  bool supportsFormatCombination(int pos,
                                 const nvinfer1::PluginTensorDesc* inOut,
                                 int nbInputs, int nbOutputs) noexcept override;

  int initialize() noexcept override;

  void terminate() noexcept override;

  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nbInputs,
                          const nvinfer1::PluginTensorDesc* outputs,
                          int nbOutputs) const noexcept override;

  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
              const nvinfer1::PluginTensorDesc* outputDesc,
              const void* const* inputs, void* const* outputs, void* workspace,
              cudaStream_t stream) noexcept override;

  size_t getSerializationSize() const noexcept override;

  void serialize(void* buffer) const noexcept override;

  const char* getPluginType() const noexcept override;

  const char* getPluginVersion() const noexcept override;
  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                       int nbInputs,
                       const nvinfer1::DynamicPluginTensorDesc* out,
                       int nbOutputs) noexcept override;
  void destroy() noexcept override;

  nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;

 private:
  std::vector<int32_t> output_size_;
  std::string pooling_type_;
};

class AdaptivePool2dPluginCreator : public BaseCreator {
 public:
  AdaptivePool2dPluginCreator();

  ~AdaptivePool2dPluginCreator() override = default;

  const char* getPluginName() const noexcept override;

  const char* getPluginVersion() const noexcept override;

  const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override;

  nvinfer1::IPluginV2DynamicExt*
  createPlugin(const char* name,
               const nvinfer1::PluginFieldCollection* fc) noexcept override;

  nvinfer1::IPluginV2DynamicExt*
  deserializePlugin(const char* name, const void* serialData,
                    size_t serialLength) noexcept override;

 private:
  static nvinfer1::PluginFieldCollection mFC;
  static std::vector<nvinfer1::PluginField> mPluginAttributes;
  std::vector<int32_t> output_size_;
  std::string pooling_type_;
};

REGISTER_TENSORRT_PLUGIN(AdaptivePool2dPluginCreator);
@@ -17,40 +17,40 @@
#include "NvInferPlugin.h"
#include "NvInferRuntimeCommon.h"
#include "fastdeploy/utils/utils.h"
#include <cstring>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

namespace fastdeploy {

class BasePlugin : public nvinfer1::IPluginV2DynamicExt {
 protected:
  void setPluginNamespace(const char* libNamespace) noexcept override {
    mNamespace = libNamespace;
  }

  const char* getPluginNamespace() const noexcept override {
    return mNamespace.c_str();
  }

  std::string mNamespace;
};

class BaseCreator : public nvinfer1::IPluginCreator {
 public:
  void setPluginNamespace(const char* libNamespace) noexcept override {
    mNamespace = libNamespace;
  }

  const char* getPluginNamespace() const noexcept override {
    return mNamespace.c_str();
  }

 protected:
  std::string mNamespace;
};

typedef enum {

@@ -62,19 +62,17 @@ typedef enum {
} pluginStatus_t;

// Write values into buffer
template <typename T> void write(char*& buffer, const T& val) {
  std::memcpy(buffer, &val, sizeof(T));
  buffer += sizeof(T);
}

// Read values from buffer
template <typename T> T read(const char*& buffer) {
  T val{};
  std::memcpy(&val, buffer, sizeof(T));
  buffer += sizeof(T);
  return val;
}

}  // namespace fastdeploy
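The write/read helpers above advance the buffer pointer as they copy, which is how AdaptivePool2d::serialize and the deserializing constructor walk the same fixed layout of five int32 values (four output dims plus a pooling-type flag). A minimal round-trip sketch under that assumption; the storage buffer and values are illustrative only, not part of the plugin API:

    // Mirrors the 5 * sizeof(int32_t) layout reported by getSerializationSize().
    char storage[5 * sizeof(int32_t)];
    char* out = storage;
    for (int32_t dim : {1, 3, 7, 7}) write(out, dim);  // four output dims
    write(out, int32_t{0});                            // 0 == "avg", 1 == "max"

    const char* in = storage;
    int32_t dims[4];
    for (int64_t i = 0; i < 4; i++) dims[i] = read<int32_t>(in);
    bool is_avg = (read<int32_t>(in) == 0);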
@@ -134,9 +134,9 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file,
  int calibration_cache_size = 0;
  if (!paddle2onnx::Export(model_file.c_str(), params_file.c_str(),
                           &model_content_ptr, &model_content_size, 11, true,
                           verbose, true, true, true, ops.data(), 1, "tensorrt",
                           &calibration_cache_ptr, &calibration_cache_size, "",
                           &save_external_)) {
    FDERROR << "Error occured while export PaddlePaddle to ONNX format."
            << std::endl;
    return false;

@@ -152,11 +152,11 @@ bool TrtBackend::InitFromPaddle(const std::string& model_file,
    calibration_str_ = calibration_str;
    delete[] calibration_cache_ptr;
  }
  if (save_external_) {
    model_file_name_ = "model.onnx";
    std::fstream f(model_file_name_, std::ios::out);
    FDASSERT(f.is_open(), "Can not open file: %s to save model.",
             model_file_name_.c_str());
    f << onnx_model_proto;
    f.close();
    return InitFromOnnx(model_file_name_, option, false);

@@ -215,13 +215,14 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
  outputs_desc_.resize(onnx_reader.num_outputs);
  for (int i = 0; i < onnx_reader.num_inputs; ++i) {
    std::string name(onnx_reader.inputs[i].name);
    std::vector<int64_t> shape(onnx_reader.inputs[i].shape,
                               onnx_reader.inputs[i].shape +
                                   onnx_reader.inputs[i].rank);
    inputs_desc_[i].name = name;
    inputs_desc_[i].shape.assign(shape.begin(), shape.end());
    inputs_desc_[i].dtype = ReaderDtypeToTrtDtype(onnx_reader.inputs[i].dtype);
    inputs_desc_[i].original_dtype =
        ReaderDtypeToFDDtype(onnx_reader.inputs[i].dtype);
    auto info = ShapeRangeInfo(shape);
    info.name = name;
    auto iter_min = option.min_shape.find(name);

@@ -237,9 +238,9 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,

  for (int i = 0; i < onnx_reader.num_outputs; ++i) {
    std::string name(onnx_reader.outputs[i].name);
    std::vector<int64_t> shape(onnx_reader.outputs[i].shape,
                               onnx_reader.outputs[i].shape +
                                   onnx_reader.outputs[i].rank);
    outputs_desc_[i].name = name;
    outputs_desc_[i].shape.assign(shape.begin(), shape.end());
    outputs_desc_[i].dtype =

@@ -252,10 +253,10 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
    stream_ = reinterpret_cast<cudaStream_t>(option_.external_stream_);
  } else {
    FDASSERT(cudaStreamCreate(&stream_) == 0,
             "[ERROR] Error occurs while calling cudaStreamCreate().");
  }

  if (save_external_) {
    onnx_content.clear();
    onnx_content = model_file_name_;
  }

@@ -283,8 +284,7 @@ int TrtBackend::ShapeRangeInfoUpdated(const std::vector<FDTensor>& inputs) {
}

bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
                       std::vector<FDTensor>* outputs, bool copy_to_fd) {
  if (inputs.size() != NumInputs()) {
    FDERROR << "Require " << NumInputs() << "inputs, but get " << inputs.size()
            << "." << std::endl;

@@ -297,7 +297,8 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
        << "TensorRT engine will be rebuilt once shape range information "
           "changed, this may take lots of time, you can set a proper shape "
           "range before loading model to avoid rebuilding process. refer "
           "https://github.com/PaddlePaddle/FastDeploy/blob/develop/docs/en/"
           "faq/"
           "tensorrt_tricks.md for more details."
        << std::endl;
    BuildTrtEngine();
@@ -314,38 +315,42 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
  for (size_t i = 0; i < outputs->size(); ++i) {
    // if the final output tensor's dtype is different from the model output tensor's dtype,
    // then we need cast the data to the final output's dtype
    auto model_output_dtype =
        GetFDDataType(outputs_device_buffer_[(*outputs)[i].name].dtype());
    if ((*outputs)[i].dtype != model_output_dtype) {
      FDTensor output_tensor;
      output_tensor.SetExternalData(
          (*outputs)[i].shape, model_output_dtype,
          outputs_device_buffer_[(*outputs)[i].name].data(), Device::GPU);

      casted_output_tensors_[(*outputs)[i].name].Resize(
          (*outputs)[i].shape, (*outputs)[i].dtype, (*outputs)[i].name,
          Device::GPU);
      function::CudaCast(output_tensor,
                         &casted_output_tensors_[(*outputs)[i].name], stream_);
      if (!copy_to_fd) {
        (*outputs)[i].SetExternalData(
            (*outputs)[i].shape, model_output_dtype,
            casted_output_tensors_[(*outputs)[i].name].MutableData(),
            Device::GPU, option_.gpu_id);
      }
    } else {
      casted_output_tensors_[(*outputs)[i].name].SetExternalData(
          (*outputs)[i].shape, model_output_dtype,
          outputs_device_buffer_[(*outputs)[i].name].data(), Device::GPU);
    }
  }
  if (copy_to_fd) {
    for (size_t i = 0; i < outputs->size(); ++i) {
      FDASSERT(
          cudaMemcpyAsync((*outputs)[i].Data(),
                          casted_output_tensors_[(*outputs)[i].name].Data(),
                          (*outputs)[i].Nbytes(), cudaMemcpyDeviceToHost,
                          stream_) == 0,
          "[ERROR] Error occurs while copy memory from GPU to CPU.");
    }
    FDASSERT(cudaStreamSynchronize(stream_) == cudaSuccess,
             "[ERROR] Error occurs while sync cuda stream.");
  }

  return true;
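As the hunk above shows, copy_to_fd decides where the results live: when it is true the backend copies every output back to host memory and synchronizes the CUDA stream, and when it is false the output tensors keep pointing at the backend's GPU buffers through SetExternalData. A minimal call sketch under that reading; backend, inputs and outputs are placeholder variables, not names taken from the diff:

    std::vector<fastdeploy::FDTensor> outputs;
    // Results copied to host; safe to read from CPU code after the call returns.
    backend.Infer(inputs, &outputs, /*copy_to_fd=*/true);
    // Results stay on the GPU; outputs alias the backend's device buffers.
    backend.Infer(inputs, &outputs, /*copy_to_fd=*/false);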
@@ -356,10 +361,12 @@ void TrtBackend::GetInputOutputInfo() {
  std::unordered_map<std::string, FDDataType> inputs_original_dtype_map;
  std::unordered_map<std::string, FDDataType> outputs_original_dtype_map;
  for (size_t i = 0; i < inputs_desc_.size(); ++i) {
    inputs_original_dtype_map[inputs_desc_[i].name] =
        inputs_desc_[i].original_dtype;
  }
  for (size_t i = 0; i < outputs_desc_.size(); ++i) {
    outputs_original_dtype_map[outputs_desc_[i].name] =
        outputs_desc_[i].original_dtype;
  }

  // Re-read the tensor infos from TRT model and write into inputs_desc_ and outputs_desc_

@@ -373,12 +380,18 @@ void TrtBackend::GetInputOutputInfo() {
    auto shape = ToVec(engine_->getBindingDimensions(i));
    auto dtype = engine_->getBindingDataType(i);
    if (engine_->bindingIsInput(i)) {
      auto original_dtype = inputs_original_dtype_map.count(name)
                                ? inputs_original_dtype_map[name]
                                : GetFDDataType(dtype);
      inputs_desc_.emplace_back(
          TrtValueInfo{name, shape, dtype, original_dtype});
      inputs_device_buffer_[name] = FDDeviceBuffer(dtype);
    } else {
      auto original_dtype = outputs_original_dtype_map.count(name)
                                ? outputs_original_dtype_map[name]
                                : GetFDDataType(dtype);
      outputs_desc_.emplace_back(
          TrtValueInfo{name, shape, dtype, original_dtype});
      outputs_device_buffer_[name] = FDDeviceBuffer(dtype);
      casted_output_tensors_[name] = FDTensor();
    }

@@ -391,8 +404,9 @@ void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
  for (const auto& item : inputs) {
    // auto idx = engine_->getBindingIndex(item.name.c_str());
    auto iter = io_name_index_.find(item.name);
    FDASSERT(iter != io_name_index_.end(),
             "TRTBackend SetInputs not find name:%s", item.name.c_str());
    auto idx = iter->second;
    std::vector<int> shape(item.shape.begin(), item.shape.end());
    auto dims = ToDims(shape);
    context_->setBindingDimensions(idx, dims);

@@ -424,9 +438,8 @@ void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
               "Error occurs while copy memory from CPU to GPU.");
    } else {
      FDASSERT(cudaMemcpyAsync(inputs_device_buffer_[item.name].data(),
                               item.Data(), item.Nbytes(),
                               cudaMemcpyHostToDevice, stream_) == 0,
               "Error occurs while copy memory from CPU to GPU.");
    }
  }

@@ -443,8 +456,10 @@ void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs,
  for (size_t i = 0; i < outputs_desc_.size(); ++i) {
    // auto idx = engine_->getBindingIndex(outputs_desc_[i].name.c_str());
    auto idx_iter = io_name_index_.find(outputs_desc_[i].name);
    FDASSERT(idx_iter != io_name_index_.end(),
             "TRTBackend Outputs not find name:%s",
             outputs_desc_[i].name.c_str());
    auto idx = idx_iter->second;
    auto output_dims = context_->getBindingDimensions(idx);

    // find the original index of output

@@ -457,23 +472,22 @@ void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs,

    // Allocate output buffer memory
    outputs_device_buffer_[outputs_desc_[i].name].resize(output_dims);

    // binding output buffer
    bindings_[idx] = outputs_device_buffer_[outputs_desc_[i].name].data();

    // set user's outputs info
    std::vector<int64_t> shape(output_dims.d,
                               output_dims.d + output_dims.nbDims);
    if (copy_to_fd) {
      (*outputs)[ori_idx].is_pinned_memory = option_.enable_pinned_memory;
      (*outputs)[ori_idx].Resize(shape, outputs_desc_[i].original_dtype,
                                 outputs_desc_[i].name);
    } else {
      (*outputs)[ori_idx].name = outputs_desc_[i].name;
      (*outputs)[ori_idx].SetExternalData(
          shape, outputs_desc_[i].original_dtype, bindings_[idx], Device::GPU,
          option_.gpu_id);
    }
  }
}

@@ -587,7 +601,8 @@ bool TrtBackend::BuildTrtEngine() {
  if (option_.serialize_file != "") {
    FDINFO << "Serialize TensorRTEngine to local file "
           << option_.serialize_file << "." << std::endl;
    std::ofstream engine_file(option_.serialize_file.c_str(),
                              std::ios::binary | std::ios::out);
    if (!engine_file) {
      FDERROR << "Failed to open " << option_.serialize_file << " to write."
              << std::endl;

@@ -628,10 +643,11 @@ bool TrtBackend::CreateTrtEngineFromOnnx(const std::string& onnx_model_buffer) {
    return false;
  }
  bool model_parser;
  if (save_external_) {
    model_parser = !parser_->parseFromFile(onnx_model_buffer.c_str(), 0);
  } else {
    model_parser =
        !parser_->parse(onnx_model_buffer.data(), onnx_model_buffer.size());
  }
  if (model_parser) {
    FDERROR << "Failed to parse ONNX model by TensorRT." << std::endl;

@@ -665,7 +681,8 @@ bool TrtBackend::CreateTrtEngineFromOnnx(const std::string& onnx_model_buffer) {
         "should be noticed that FastDeploy will rebuild the engine while "
         "new input shape is out of the collected shape range, this may "
         "bring some time consuming problem, refer "
         "https://github.com/PaddlePaddle/FastDeploy/blob/develop/docs/en/"
         "faq/"
         "tensorrt_tricks.md for more details."
      << std::endl;
  initialized_ = true;

@@ -721,27 +738,24 @@ std::vector<TensorInfo> TrtBackend::GetOutputInfos() {
  return infos;
}

std::unique_ptr<BaseBackend> TrtBackend::Clone(void* stream, int device_id) {
  std::unique_ptr<BaseBackend> new_backend = utils::make_unique<TrtBackend>();
  auto casted_backend = dynamic_cast<TrtBackend*>(new_backend.get());
  if (device_id > 0 && device_id != option_.gpu_id) {
    auto clone_option = option_;
    clone_option.gpu_id = device_id;
    clone_option.external_stream_ = stream;
    if (option_.model_format == ModelFormat::ONNX) {
      FDASSERT(casted_backend->InitFromOnnx(option_.model_file, clone_option),
               "Clone model from ONNX failed while initialize TrtBackend.");
    } else {
      FDASSERT(casted_backend->InitFromPaddle(
                   option_.model_file, option_.params_file, clone_option),
               "Clone model from Paddle failed while initialize TrtBackend.");
    }
    FDWARNING << "The target device id:" << device_id
              << " is different from current device id:" << option_.gpu_id
              << ", cannot share memory with current engine." << std::endl;
    return new_backend;
  }
  cudaSetDevice(option_.gpu_id);

@@ -750,12 +764,15 @@ std::unique_ptr<BaseBackend> TrtBackend::Clone(void *stream, int device_id) {
    casted_backend->stream_ = reinterpret_cast<cudaStream_t>(stream);
  } else {
    FDASSERT(cudaStreamCreate(&casted_backend->stream_) == 0,
             "[ERROR] Error occurs while clone calling cudaStreamCreate().");
  }
  casted_backend->inputs_desc_.assign(inputs_desc_.begin(), inputs_desc_.end());
  casted_backend->outputs_desc_.assign(outputs_desc_.begin(),
                                       outputs_desc_.end());
  casted_backend->outputs_order_.insert(outputs_order_.begin(),
                                        outputs_order_.end());
  casted_backend->shape_range_info_.insert(shape_range_info_.begin(),
                                           shape_range_info_.end());
  casted_backend->engine_ = engine_;
  casted_backend->context_ = std::shared_ptr<nvinfer1::IExecutionContext>(
      casted_backend->engine_->createExecutionContext());
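A brief, hedged illustration of the Clone interface reformatted above: cloning onto the current device shares the built engine and only creates a new execution context, while passing a different device_id re-initializes the model on that device with no shared memory. The variable names below are placeholders, not identifiers from the diff:

    // Same device (defaults): the clone reuses engine_ and gets its own execution context.
    std::unique_ptr<fastdeploy::BaseBackend> shared_clone = trt_backend.Clone();
    // Different device: the model is loaded again there, so nothing is shared.
    std::unique_ptr<fastdeploy::BaseBackend> separate_clone =
        trt_backend.Clone(/*stream=*/nullptr, /*device_id=*/1);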
@@ -58,7 +58,7 @@ namespace fastdeploy {
 struct TrtValueInfo {
   std::string name;
   std::vector<int> shape;
   nvinfer1::DataType dtype;   // dtype of TRT model
   FDDataType original_dtype;  // dtype of original ONNX/Paddle model
 };

@@ -97,8 +97,7 @@ class TrtBackend : public BaseBackend {
   bool InitFromOnnx(const std::string& model_file,
                     const TrtBackendOption& option = TrtBackendOption(),
                     bool from_memory_buffer = false);
-  bool Infer(std::vector<FDTensor>& inputs,
-             std::vector<FDTensor>* outputs,
+  bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs,
              bool copy_to_fd = true) override;

   int NumInputs() const { return inputs_desc_.size(); }
@@ -107,7 +106,7 @@ class TrtBackend : public BaseBackend {
   TensorInfo GetOutputInfo(int index);
   std::vector<TensorInfo> GetInputInfos() override;
   std::vector<TensorInfo> GetOutputInfos() override;
-  std::unique_ptr<BaseBackend> Clone(void *stream = nullptr,
+  std::unique_ptr<BaseBackend> Clone(void* stream = nullptr,
                                      int device_id = -1) override;

   ~TrtBackend() {
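Not part of the commit: a minimal usage sketch of the Clone/Infer interfaces touched above, assuming an already-initialized TrtBackend; the include path, the helper name InferOnOwnStream, and device id 0 are assumptions for illustration.

#include <memory>
#include <vector>

#include <cuda_runtime_api.h>

#include "fastdeploy/backends/tensorrt/trt_backend.h"  // include path assumed

// Run a cloned TrtBackend on its own CUDA stream so a second worker does not
// have to rebuild the engine.
void InferOnOwnStream(fastdeploy::TrtBackend& backend,
                      std::vector<fastdeploy::FDTensor>& inputs) {
  cudaStream_t stream;
  cudaStreamCreate(&stream);
  // Same device id as the source engine: the clone shares the engine.
  // A different id takes the FDWARNING path above and rebuilds instead.
  std::unique_ptr<fastdeploy::BaseBackend> cloned =
      backend.Clone(stream, /*device_id=*/0);
  std::vector<fastdeploy::FDTensor> outputs;
  cloned->Infer(inputs, &outputs, /*copy_to_fd=*/true);
  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
}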
@@ -32,17 +32,15 @@
 namespace fastdeploy {

 struct FDInferDeleter {
-  template <typename T>
-  void operator()(T* obj) const {
+  template <typename T> void operator()(T* obj) const {
     if (obj) {
       delete obj;
       // obj->destroy();
     }
   }
 };

-template <typename T>
-using FDUniquePtr = std::unique_ptr<T, FDInferDeleter>;
+template <typename T> using FDUniquePtr = std::unique_ptr<T, FDInferDeleter>;

 int64_t Volume(const nvinfer1::Dims& d);

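Not part of the commit: a self-contained sketch of how the FDInferDeleter/FDUniquePtr pair above behaves; the Dummy type is a stand-in for a TensorRT object such as nvinfer1::IRuntime.

#include <iostream>
#include <memory>

// Same shape as the utility above: a deleter that plain-deletes the object.
struct FDInferDeleter {
  template <typename T> void operator()(T* obj) const {
    if (obj) {
      delete obj;
    }
  }
};

template <typename T> using FDUniquePtr = std::unique_ptr<T, FDInferDeleter>;

struct Dummy {  // stand-in for a TensorRT object
  ~Dummy() { std::cout << "released\n"; }
};

int main() {
  FDUniquePtr<Dummy> ptr(new Dummy());
  return 0;  // "released" is printed when ptr goes out of scope
}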
@@ -72,17 +70,13 @@ std::ostream& operator<<(std::ostream& out, const std::vector<T>& vec) {
   return out;
 }

-template <typename AllocFunc, typename FreeFunc>
-class FDGenericBuffer {
+template <typename AllocFunc, typename FreeFunc> class FDGenericBuffer {
  public:
   //!
   //! \brief Construct an empty buffer.
   //!
   explicit FDGenericBuffer(nvinfer1::DataType type = nvinfer1::DataType::kFLOAT)
-      : mSize(0),
-        mCapacity(0),
-        mType(type),
-        mBuffer(nullptr),
+      : mSize(0), mCapacity(0), mType(type), mBuffer(nullptr),
         mExternal_buffer(nullptr) {}

   //!
@@ -104,9 +98,7 @@ class FDGenericBuffer {
   }

   FDGenericBuffer(FDGenericBuffer&& buf)
-      : mSize(buf.mSize),
-        mCapacity(buf.mCapacity),
-        mType(buf.mType),
+      : mSize(buf.mSize), mCapacity(buf.mCapacity), mType(buf.mType),
         mBuffer(buf.mBuffer) {
     buf.mSize = 0;
     buf.mCapacity = 0;
@@ -133,7 +125,8 @@ class FDGenericBuffer {
   //! \brief Returns pointer to underlying array.
   //!
   void* data() {
-    if (mExternal_buffer != nullptr) return mExternal_buffer;
+    if (mExternal_buffer != nullptr)
+      return mExternal_buffer;
     return mBuffer;
   }

@@ -141,7 +134,8 @@ class FDGenericBuffer {
   //! \brief Returns pointer to underlying array.
   //!
   const void* data() const {
-    if (mExternal_buffer != nullptr) return mExternal_buffer;
+    if (mExternal_buffer != nullptr)
+      return mExternal_buffer;
     return mBuffer;
   }

@@ -213,8 +207,8 @@ class FDGenericBuffer {
 };

 using FDDeviceBuffer = FDGenericBuffer<FDDeviceAllocator, FDDeviceFree>;
-using FDDeviceHostBuffer = FDGenericBuffer<FDDeviceHostAllocator,
-                                           FDDeviceHostFree>;
+using FDDeviceHostBuffer =
+    FDGenericBuffer<FDDeviceHostAllocator, FDDeviceHostFree>;

 class FDTrtLogger : public nvinfer1::ILogger {
  public:
@@ -12,13 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "fastdeploy/backends/common/multiclass_nms.h"
+#include "fastdeploy/vision/detection/ppdet/multiclass_nms.h"
 #include <algorithm>
 #include "fastdeploy/core/fd_tensor.h"
 #include "fastdeploy/utils/utils.h"

 namespace fastdeploy {
-namespace backend {
+namespace vision {
+namespace detection {
 template <class T>
 bool SortScorePairDescend(const std::pair<float, T>& pair1,
                           const std::pair<float, T>& pair2) {
@@ -79,7 +80,7 @@ float JaccardOverlap(const float* box1, const float* box2,
   }
 }

-void MultiClassNMS::FastNMS(const float* boxes, const float* scores,
+void PaddleMultiClassNMS::FastNMS(const float* boxes, const float* scores,
                             const int& num_boxes,
                             std::vector<int>* keep_indices) {
   std::vector<std::pair<float, int>> sorted_indices;
@@ -109,7 +110,7 @@ void MultiClassNMS::FastNMS(const float* boxes, const float* scores,
   }
 }

-int MultiClassNMS::NMSForEachSample(
+int PaddleMultiClassNMS::NMSForEachSample(
     const float* boxes, const float* scores, int num_boxes, int num_classes,
     std::map<int, std::vector<int>>* keep_indices) {
   for (int i = 0; i < num_classes; ++i) {
@@ -152,7 +153,7 @@ int MultiClassNMS::NMSForEachSample(
   return num_det;
 }

-void MultiClassNMS::Compute(const float* boxes_data, const float* scores_data,
+void PaddleMultiClassNMS::Compute(const float* boxes_data, const float* scores_data,
                             const std::vector<int64_t>& boxes_dim,
                             const std::vector<int64_t>& scores_dim) {
   int score_size = scores_dim.size();
@@ -220,5 +221,6 @@ void MultiClassNMS::Compute(const float* boxes_data, const float* scores_data,
     }
   }
 }
-}  // namespace backend
+}  // namespace detection
+}  // namespace vision
 }  // namespace fastdeploy
@@ -18,8 +18,9 @@
 #include <vector>

 namespace fastdeploy {
-namespace backend {
-struct MultiClassNMS {
+namespace vision {
+namespace detection {
+struct PaddleMultiClassNMS {
   int64_t background_label = -1;
   int64_t keep_top_k = -1;
   float nms_eta;
@@ -40,6 +41,6 @@ struct MultiClassNMS {
                const std::vector<int64_t>& boxes_dim,
                const std::vector<int64_t>& scores_dim);
 };
-}  // namespace backend
+}  // namespace detection
+}  // namespace vision
 }  // namespace fastdeploy
@@ -13,6 +13,7 @@
 // limitations under the License.

 #include "fastdeploy/vision/detection/ppdet/postprocessor.h"
+#include "fastdeploy/vision/detection/ppdet/multiclass_nms.h"
 #include "fastdeploy/vision/utils/utils.h"

 namespace fastdeploy {
@@ -176,7 +177,7 @@ bool PaddleDetPostprocessor::ProcessUnDecodeResults(
     return false;
   }

-  backend::MultiClassNMS nms;
+  PaddleMultiClassNMS nms;
   nms.background_label = -1;
   nms.keep_top_k = 100;
   nms.nms_eta = 1.0;
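Not part of the commit: a hedged sketch of driving the renamed NMS helper directly, mirroring the postprocessor code above. The box/score layouts and the helper name RunNms are assumptions for illustration; only members visible in this diff are set.

#include <cstdint>
#include <vector>

#include "fastdeploy/vision/detection/ppdet/multiclass_nms.h"

// Configure the renamed helper and run it on raw detection outputs.
void RunNms(const float* boxes_data, const float* scores_data) {
  fastdeploy::vision::detection::PaddleMultiClassNMS nms;
  nms.background_label = -1;
  nms.keep_top_k = 100;
  nms.nms_eta = 1.0;
  std::vector<int64_t> boxes_dim = {1, 1000, 4};    // assumed [batch, boxes, 4]
  std::vector<int64_t> scores_dim = {1, 80, 1000};  // assumed [batch, classes, boxes]
  nms.Compute(boxes_data, scores_data, boxes_dim, scores_dim);
}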