Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-06 09:07:10 +08:00
FDTensor support GPU device (#190)
* fdtensor support GPU
* TRT backend support GPU FDTensor
* FDHostAllocator add FASTDEPLOY_DECL
* fix FDTensor Data
* fix FDTensor dtype

Co-authored-by: Jason <jiangjiajun@baidu.com>
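For context, here is a hedged caller-side sketch of what "FDTensor support GPU device" enables: handing TrtBackend::Infer a tensor whose buffer already resides on the GPU, so the backend binds the pointer directly instead of copying from host memory. The members set below (name, shape, dtype, device) all appear in this diff; the SetExternalData call on FDTensor is an assumed helper whose exact signature should be checked against the FDTensor header of the matching revision.

#include <vector>
#include "fastdeploy/backends/tensorrt/trt_backend.h"

// Hypothetical usage sketch, not code from this commit.
void InferOnGpuResidentInput(fastdeploy::TrtBackend& backend, void* d_ptr) {
  fastdeploy::FDTensor input;
  input.name = "x";                          // must match an engine binding
  input.shape = {1, 3, 224, 224};
  input.dtype = fastdeploy::FDDataType::FP32;
  input.device = fastdeploy::Device::GPU;    // selects the zero-copy branch
  input.SetExternalData(input.shape, input.dtype, d_ptr);  // assumed API
  std::vector<fastdeploy::FDTensor> inputs = {input};
  std::vector<fastdeploy::FDTensor> outputs;
  // SetInputs wraps d_ptr via the buffer's SetExternalData; an INT64 GPU
  // input would instead trip the FDASSERT added in this commit.
  backend.Infer(inputs, &outputs);
}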
@@ -13,9 +13,11 @@
 // limitations under the License.
 
 #include "fastdeploy/backends/tensorrt/trt_backend.h"
+
+#include <cstring>
+
 #include "NvInferSafeRuntime.h"
 #include "fastdeploy/utils/utils.h"
-#include <cstring>
 #ifdef ENABLE_PADDLE_FRONTEND
 #include "paddle2onnx/converter.h"
 #endif
@@ -210,9 +212,9 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
   outputs_desc_.resize(onnx_reader.num_outputs);
   for (int i = 0; i < onnx_reader.num_inputs; ++i) {
     std::string name(onnx_reader.inputs[i].name);
-    std::vector<int64_t> shape(onnx_reader.inputs[i].shape,
-                               onnx_reader.inputs[i].shape +
-                                   onnx_reader.inputs[i].rank);
+    std::vector<int64_t> shape(
+        onnx_reader.inputs[i].shape,
+        onnx_reader.inputs[i].shape + onnx_reader.inputs[i].rank);
     inputs_desc_[i].name = name;
     inputs_desc_[i].shape.assign(shape.begin(), shape.end());
     inputs_desc_[i].dtype = ReaderDtypeToTrtDtype(onnx_reader.inputs[i].dtype);
@@ -231,9 +233,9 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
 
   for (int i = 0; i < onnx_reader.num_outputs; ++i) {
     std::string name(onnx_reader.outputs[i].name);
-    std::vector<int64_t> shape(onnx_reader.outputs[i].shape,
-                               onnx_reader.outputs[i].shape +
-                                   onnx_reader.outputs[i].rank);
+    std::vector<int64_t> shape(
+        onnx_reader.outputs[i].shape,
+        onnx_reader.outputs[i].shape + onnx_reader.outputs[i].rank);
     outputs_desc_[i].name = name;
     outputs_desc_[i].shape.assign(shape.begin(), shape.end());
     outputs_desc_[i].dtype =
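The two InitFromOnnx hunks above are pure reformatting: both construct a std::vector<int64_t> from the reader's C-style shape array through the iterator-range constructor. A minimal standalone illustration of that idiom (the array and rank values are made up; the names mirror onnx_reader.inputs[i].shape and .rank from the diff):

#include <cstdint>
#include <vector>

int main() {
  // Stand-in for onnx_reader.inputs[i].shape / .rank: a C array plus a
  // rank saying how many leading entries are meaningful.
  int64_t shape_arr[] = {1, 3, 224, 224};
  int rank = 4;
  // Iterator-range constructor: copies shape_arr[0..rank) into the vector.
  std::vector<int64_t> shape(shape_arr, shape_arr + rank);
  return shape.size() == 4 ? 0 : 1;
}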
@@ -286,24 +288,8 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
     BuildTrtEngine();
   }
 
-  AllocateBufferInDynamicShape(inputs, outputs);
-  std::vector<void*> input_binds(inputs.size());
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    if (inputs[i].dtype == FDDataType::INT64) {
-      int64_t* data = static_cast<int64_t*>(inputs[i].Data());
-      std::vector<int32_t> casted_data(data, data + inputs[i].Numel());
-      FDASSERT(cudaMemcpyAsync(inputs_buffer_[inputs[i].name].data(),
-                               static_cast<void*>(casted_data.data()),
-                               inputs[i].Nbytes() / 2, cudaMemcpyHostToDevice,
-                               stream_) == 0,
-               "[ERROR] Error occurs while copy memory from CPU to GPU.");
-    } else {
-      FDASSERT(cudaMemcpyAsync(inputs_buffer_[inputs[i].name].data(),
-                               inputs[i].Data(), inputs[i].Nbytes(),
-                               cudaMemcpyHostToDevice, stream_) == 0,
-               "[ERROR] Error occurs while copy memory from CPU to GPU.");
-    }
-  }
+  SetInputs(inputs);
+  AllocateOutputsBuffer(outputs);
   if (!context_->enqueueV2(bindings_.data(), stream_, nullptr)) {
     FDERROR << "Failed to Infer with TensorRT." << std::endl;
     return false;
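The slimmed-down Infer leans on the TensorRT enqueueV2 contract: the bindings array must hold one device pointer per engine binding, ordered by binding index, which is why both SetInputs and AllocateOutputsBuffer write into bindings_[idx] before enqueueV2 runs. A hedged standalone sketch of that pattern (the binding names "x"/"y" and the device pointers are placeholders; the TensorRT calls themselves are the stock 7.x/8.x API):

#include <cuda_runtime.h>
#include <vector>
#include "NvInfer.h"

// Fill the binding table by name lookup, then queue inference on a stream.
bool Run(nvinfer1::ICudaEngine* engine, nvinfer1::IExecutionContext* context,
         void* d_input, void* d_output, cudaStream_t stream) {
  std::vector<void*> bindings(engine->getNbBindings(), nullptr);
  bindings[engine->getBindingIndex("x")] = d_input;   // input slot
  bindings[engine->getBindingIndex("y")] = d_output;  // output slot
  // enqueueV2 only queues the work; synchronize the stream before
  // reading d_output on the host.
  return context->enqueueV2(bindings.data(), stream, nullptr);
}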
@@ -339,18 +325,50 @@ void TrtBackend::GetInputOutputInfo() {
   bindings_.resize(num_binds);
 }
 
-void TrtBackend::AllocateBufferInDynamicShape(
-    const std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs) {
+void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
   for (const auto& item : inputs) {
     auto idx = engine_->getBindingIndex(item.name.c_str());
     std::vector<int> shape(item.shape.begin(), item.shape.end());
     auto dims = ToDims(shape);
     context_->setBindingDimensions(idx, dims);
-    if (item.Nbytes() > inputs_buffer_[item.name].nbBytes()) {
+
+    if (item.device == Device::GPU) {
+      if (item.dtype == FDDataType::INT64) {
+        // TODO(liqi): cast int64 to int32
+        // TRT don't support INT64
+        FDASSERT(false,
+                 "TRT don't support INT64 input on GPU, "
+                 "please use INT32 input");
+      } else {
+        // no copy
+        inputs_buffer_[item.name].SetExternalData(dims, item.Data());
+      }
+    } else {
+      // Allocate input buffer memory
       inputs_buffer_[item.name].resize(dims);
-      bindings_[idx] = inputs_buffer_[item.name].data();
+
+      // copy from cpu to gpu
+      if (item.dtype == FDDataType::INT64) {
+        int64_t* data = static_cast<int64_t*>(const_cast<void*>(item.Data()));
+        std::vector<int32_t> casted_data(data, data + item.Numel());
+        FDASSERT(cudaMemcpyAsync(inputs_buffer_[item.name].data(),
+                                 static_cast<void*>(casted_data.data()),
+                                 item.Nbytes() / 2, cudaMemcpyHostToDevice,
+                                 stream_) == 0,
+                 "Error occurs while copy memory from CPU to GPU.");
+      } else {
+        FDASSERT(cudaMemcpyAsync(inputs_buffer_[item.name].data(), item.Data(),
+                                 item.Nbytes(), cudaMemcpyHostToDevice,
+                                 stream_) == 0,
+                 "Error occurs while copy memory from CPU to GPU.");
+      }
+    }
+    // binding input buffer
+    bindings_[idx] = inputs_buffer_[item.name].data();
   }
 }
 
+void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs) {
   if (outputs->size() != outputs_desc_.size()) {
     outputs->resize(outputs_desc_.size());
   }
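The CPU branch of SetInputs keeps the pre-existing INT64 workaround: TensorRT has no INT64 tensor type, so the host data is narrowed to int32 before the copy, and only item.Nbytes() / 2 bytes are shipped because the cast halves each element from 8 bytes to 4. A standalone illustration of that arithmetic (values made up):

#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  // Stand-in for an FDTensor with dtype INT64: 4 elements, 32 bytes.
  std::vector<int64_t> data = {1, 2, 3, 4};
  size_t nbytes = data.size() * sizeof(int64_t);   // 32, like Nbytes()
  // Narrow to int32 the way SetInputs does before the H2D copy.
  std::vector<int32_t> casted(data.begin(), data.end());
  size_t copy_bytes = nbytes / 2;                  // 16
  // The halved byte count equals the size of the casted buffer exactly.
  return copy_bytes == casted.size() * sizeof(int32_t) ? 0 : 1;
}

One caveat worth noting: casted_data is a stack-local vector handed to cudaMemcpyAsync, so the pattern appears to rely on host-to-device copies from pageable memory completing their host-side read before the call returns; with pinned memory the vector could go out of scope before the copy finishes.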
@@ -365,13 +383,15 @@ void TrtBackend::AllocateBufferInDynamicShape(
         "Cannot find output: %s of tensorrt network from the original model.",
         outputs_desc_[i].name.c_str());
     auto ori_idx = iter->second;
-    std::vector<int64_t> shape(output_dims.d, output_dims.d + output_dims.nbDims);
-    (*outputs)[ori_idx].Allocate(shape, GetFDDataType(outputs_desc_[i].dtype), outputs_desc_[i].name);
-    if ((*outputs)[ori_idx].Nbytes() >
-        outputs_buffer_[outputs_desc_[i].name].nbBytes()) {
-      outputs_buffer_[outputs_desc_[i].name].resize(output_dims);
-      bindings_[idx] = outputs_buffer_[outputs_desc_[i].name].data();
-    }
+    // set user's outputs info
+    std::vector<int64_t> shape(output_dims.d,
+                               output_dims.d + output_dims.nbDims);
+    (*outputs)[ori_idx].Resize(shape, GetFDDataType(outputs_desc_[i].dtype),
+                               outputs_desc_[i].name);
+    // Allocate output buffer memory
+    outputs_buffer_[outputs_desc_[i].name].resize(output_dims);
+    // binding output buffer
+    bindings_[idx] = outputs_buffer_[outputs_desc_[i].name].data();
   }
 }
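One behavioral note on the hunk above: the old code reallocated the device buffer and rebound the pointer only when an output outgrew it, while the new code resizes and rebinds on every call and switches the user-facing tensor from Allocate to Resize. If outputs_buffer_ follows the grow-only convention of the TensorRT sample buffers, the unconditional resize stays cheap once shapes stabilize. A hedged sketch of that convention (GrowOnlyDeviceBuffer is illustrative; it is not the buffer class used by this commit):

#include <cuda_runtime.h>
#include <cstddef>

// Illustrative grow-only device buffer: resize() reallocates only when the
// requested size exceeds current capacity, so per-inference resize calls
// become no-ops once the largest shape has been seen.
struct GrowOnlyDeviceBuffer {
  void* ptr = nullptr;
  size_t capacity = 0;
  void resize(size_t nbytes) {
    if (nbytes > capacity) {
      cudaFree(ptr);             // safe on nullptr
      cudaMalloc(&ptr, nbytes);  // grow to the new high-water mark
      capacity = nbytes;
    }
  }
  void* data() const { return ptr; }
  ~GrowOnlyDeviceBuffer() { cudaFree(ptr); }
};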
@@ -580,4 +600,4 @@ TensorInfo TrtBackend::GetOutputInfo(int index) {
   info.dtype = GetFDDataType(outputs_desc_[index].dtype);
   return info;
 }
-} // namespace fastdeploy
+}  // namespace fastdeploy