Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 16:48:03 +08:00
FDTensor support GPU device (#190)
* fdtensor support GPU
* TRT backend support GPU FDTensor
* FDHostAllocator add FASTDEPLOY_DECL
* fix FDTensor Data
* fix FDTensor dtype

Co-authored-by: Jason <jiangjiajun@baidu.com>
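Editor's note: the sketch below is not part of the commit; it only illustrates how the device-aware FDTensor API introduced here is meant to be used, based on the signatures that appear in the diff (Allocate, SetExternalData and CpuData with a Device argument). The shapes, tensor names and the gpu_ptr variable are hypothetical placeholders.

#include "fastdeploy/core/fd_tensor.h"

int main() {
  fastdeploy::FDTensor cpu_tensor;
  // Host allocation; the new Device parameter defaults to Device::CPU.
  cpu_tensor.Allocate({1, 3, 224, 224}, fastdeploy::FDDataType::FP32, "image");

  // Wrap an existing device pointer without copying. gpu_ptr stands in for
  // memory obtained elsewhere (e.g. via cudaMalloc); it is not allocated here.
  void* gpu_ptr = nullptr;
  fastdeploy::FDTensor gpu_tensor;
  gpu_tensor.SetExternalData({1, 1000}, fastdeploy::FDDataType::FP32, gpu_ptr,
                             fastdeploy::Device::GPU);

  // CpuData() always yields a host-visible pointer; for a GPU-resident tensor
  // it first copies the data into temporary_cpu_buffer.
  const void* host_view = cpu_tensor.CpuData();
  (void)host_view;
  (void)gpu_tensor;
  return 0;
}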
@@ -13,7 +13,9 @@
// limitations under the License.

#include "fastdeploy/backends/ort/ort_backend.h"

#include <memory>

#include "fastdeploy/backends/ort/ops/multiclass_nms.h"
#include "fastdeploy/backends/ort/utils.h"
#include "fastdeploy/utils/utils.h"

@@ -164,33 +166,34 @@ bool OrtBackend::InitFromOnnx(const std::string& model_file,
  return true;
}

void OrtBackend::CopyToCpu(const Ort::Value& value, FDTensor* tensor, const std::string& name) {
void OrtBackend::CopyToCpu(const Ort::Value& value, FDTensor* tensor,
                           const std::string& name) {
  const auto info = value.GetTensorTypeAndShapeInfo();
  const auto data_type = info.GetElementType();
  size_t numel = info.GetElementCount();
  auto shape = info.GetShape();
  FDDataType dtype;

  if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
    tensor->Allocate(info.GetShape(), FDDataType::FP32, name);
    memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
           numel * sizeof(float));
    dtype = FDDataType::FP32;
    numel *= sizeof(float);
  } else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
    tensor->Allocate(info.GetShape(), FDDataType::INT32, name);
    memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
           numel * sizeof(int32_t));
    dtype = FDDataType::INT32;
    numel *= sizeof(int32_t);
  } else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
    tensor->Allocate(info.GetShape(), FDDataType::INT64, name);
    memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
           numel * sizeof(int64_t));
    dtype = FDDataType::INT64;
    numel *= sizeof(int64_t);
  } else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
    tensor->Allocate(info.GetShape(), FDDataType::FP64, name);
    memcpy(static_cast<void*>(tensor->MutableData()), value.GetTensorData<void*>(),
           numel * sizeof(double));
    dtype = FDDataType::FP64;
    numel *= sizeof(double);
  } else {
    FDASSERT(
        false,
        "Unrecognized data type of %d while calling OrtBackend::CopyToCpu().",
        data_type);
  }
  tensor->Resize(shape, dtype, name);
  memcpy(tensor->MutableData(), value.GetTensorData<void*>(), numel);
}

bool OrtBackend::Infer(std::vector<FDTensor>& inputs,
@@ -88,6 +88,7 @@ class OrtBackend : public BaseBackend {
  Ort::CustomOpDomain custom_op_domain_ = Ort::CustomOpDomain("Paddle");
#endif
  OrtBackendOption option_;
  void CopyToCpu(const Ort::Value& value, FDTensor* tensor, const std::string& name);
  void CopyToCpu(const Ort::Value& value, FDTensor* tensor,
                 const std::string& name);
};
}  // namespace fastdeploy
@@ -79,16 +79,23 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
}

TensorInfo PaddleBackend::GetInputInfo(int index) {
  FDASSERT(index < NumInputs(), "The index: %d should less than the number of inputs: %d.", index, NumInputs());
  FDASSERT(index < NumInputs(),
           "The index: %d should less than the number of inputs: %d.", index,
           NumInputs());
  return inputs_desc_[index];
}

std::vector<TensorInfo> PaddleBackend::GetInputInfo() { return inputs_desc_; }

TensorInfo PaddleBackend::GetOutputInfo(int index) {
  FDASSERT(index < NumOutputs(),
           "The index: %d should less than the number of outputs %d.", index, NumOutputs());
           "The index: %d should less than the number of outputs %d.", index,
           NumOutputs());
  return outputs_desc_[index];
}

std::vector<TensorInfo> PaddleBackend::GetOutputInfo() { return outputs_desc_; }

bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
                          std::vector<FDTensor>* outputs) {
  if (inputs.size() != inputs_desc_.size()) {
@@ -100,7 +107,7 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,

  for (size_t i = 0; i < inputs.size(); ++i) {
    auto handle = predictor_->GetInputHandle(inputs[i].name);
    ShareTensorFromCpu(handle.get(), inputs[i]);
    ShareTensorFromFDTensor(handle.get(), inputs[i]);
  }

  predictor_->Run();
@@ -44,8 +44,11 @@ struct PaddleBackendOption {
  std::vector<std::string> delete_pass_names = {};
};

// convert FD device to paddle place type
paddle_infer::PlaceType ConvertFDDeviceToPlace(Device device);

// Share memory buffer with paddle_infer::Tensor from fastdeploy::FDTensor
void ShareTensorFromCpu(paddle_infer::Tensor* tensor, FDTensor& fd_tensor);
void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor, FDTensor& fd_tensor);

// Copy memory data from paddle_infer::Tensor to fastdeploy::FDTensor
void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
@@ -72,6 +75,8 @@ class PaddleBackend : public BaseBackend {

  TensorInfo GetInputInfo(int index);
  TensorInfo GetOutputInfo(int index);
  std::vector<TensorInfo> GetInputInfo();
  std::vector<TensorInfo> GetOutputInfo();

 private:
  paddle_infer::Config config_;
@@ -15,23 +15,33 @@
#include "fastdeploy/backends/paddle/paddle_backend.h"

namespace fastdeploy {
void ShareTensorFromCpu(paddle_infer::Tensor* tensor, FDTensor& fd_tensor) {
paddle_infer::PlaceType ConvertFDDeviceToPlace(Device device) {
  if (device == Device::GPU) {
    return paddle_infer::PlaceType::kGPU;
  }
  return paddle_infer::PlaceType::kCPU;
}

void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor,
                             FDTensor& fd_tensor) {
  std::vector<int> shape(fd_tensor.shape.begin(), fd_tensor.shape.end());
  tensor->Reshape(shape);
  auto place = ConvertFDDeviceToPlace(fd_tensor.device);
  if (fd_tensor.dtype == FDDataType::FP32) {
    tensor->ShareExternalData(static_cast<const float*>(fd_tensor.Data()),
                              shape, paddle_infer::PlaceType::kCPU);
                              shape, place);
    return;
  } else if (fd_tensor.dtype == FDDataType::INT32) {
    tensor->ShareExternalData(static_cast<const int32_t*>(fd_tensor.Data()),
                              shape, paddle_infer::PlaceType::kCPU);
                              shape, place);
    return;
  } else if (fd_tensor.dtype == FDDataType::INT64) {
    tensor->ShareExternalData(static_cast<const int64_t*>(fd_tensor.Data()),
                              shape, paddle_infer::PlaceType::kCPU);
                              shape, place);
    return;
  }
  FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.", Str(fd_tensor.dtype).c_str());
  FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
           Str(fd_tensor.dtype).c_str());
}

void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
@@ -51,7 +61,8 @@ void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
    tensor->CopyToCpu(static_cast<int64_t*>(fd_tensor->MutableData()));
    return;
  }
  FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.", Str(fd_tensor->dtype).c_str());
  FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
           Str(fd_tensor->dtype).c_str());
}

FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype) {
@@ -65,7 +76,10 @@ FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype) {
  } else if (dtype == paddle_infer::UINT8) {
    fd_dtype = FDDataType::UINT8;
  } else {
    FDASSERT(false, "Unexpected data type: %d while call CopyTensorToCpu in PaddleBackend.", int(dtype));
    FDASSERT(
        false,
        "Unexpected data type: %d while call CopyTensorToCpu in PaddleBackend.",
        int(dtype));
  }
  return fd_dtype;
}
@@ -13,9 +13,11 @@
// limitations under the License.

#include "fastdeploy/backends/tensorrt/trt_backend.h"

#include <cstring>

#include "NvInferSafeRuntime.h"
#include "fastdeploy/utils/utils.h"
#include <cstring>
#ifdef ENABLE_PADDLE_FRONTEND
#include "paddle2onnx/converter.h"
#endif
@@ -210,9 +212,9 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
  outputs_desc_.resize(onnx_reader.num_outputs);
  for (int i = 0; i < onnx_reader.num_inputs; ++i) {
    std::string name(onnx_reader.inputs[i].name);
    std::vector<int64_t> shape(onnx_reader.inputs[i].shape,
                               onnx_reader.inputs[i].shape +
                               onnx_reader.inputs[i].rank);
    std::vector<int64_t> shape(
        onnx_reader.inputs[i].shape,
        onnx_reader.inputs[i].shape + onnx_reader.inputs[i].rank);
    inputs_desc_[i].name = name;
    inputs_desc_[i].shape.assign(shape.begin(), shape.end());
    inputs_desc_[i].dtype = ReaderDtypeToTrtDtype(onnx_reader.inputs[i].dtype);
@@ -231,9 +233,9 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,

  for (int i = 0; i < onnx_reader.num_outputs; ++i) {
    std::string name(onnx_reader.outputs[i].name);
    std::vector<int64_t> shape(onnx_reader.outputs[i].shape,
                               onnx_reader.outputs[i].shape +
                               onnx_reader.outputs[i].rank);
    std::vector<int64_t> shape(
        onnx_reader.outputs[i].shape,
        onnx_reader.outputs[i].shape + onnx_reader.outputs[i].rank);
    outputs_desc_[i].name = name;
    outputs_desc_[i].shape.assign(shape.begin(), shape.end());
    outputs_desc_[i].dtype =
@@ -286,24 +288,8 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
    BuildTrtEngine();
  }

  AllocateBufferInDynamicShape(inputs, outputs);
  std::vector<void*> input_binds(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (inputs[i].dtype == FDDataType::INT64) {
      int64_t* data = static_cast<int64_t*>(inputs[i].Data());
      std::vector<int32_t> casted_data(data, data + inputs[i].Numel());
      FDASSERT(cudaMemcpyAsync(inputs_buffer_[inputs[i].name].data(),
                               static_cast<void*>(casted_data.data()),
                               inputs[i].Nbytes() / 2, cudaMemcpyHostToDevice,
                               stream_) == 0,
               "[ERROR] Error occurs while copy memory from CPU to GPU.");
    } else {
      FDASSERT(cudaMemcpyAsync(inputs_buffer_[inputs[i].name].data(),
                               inputs[i].Data(), inputs[i].Nbytes(),
                               cudaMemcpyHostToDevice, stream_) == 0,
               "[ERROR] Error occurs while copy memory from CPU to GPU.");
    }
  }
  SetInputs(inputs);
  AllocateOutputsBuffer(outputs);
  if (!context_->enqueueV2(bindings_.data(), stream_, nullptr)) {
    FDERROR << "Failed to Infer with TensorRT." << std::endl;
    return false;
@@ -339,18 +325,50 @@ void TrtBackend::GetInputOutputInfo() {
  bindings_.resize(num_binds);
}

void TrtBackend::AllocateBufferInDynamicShape(
    const std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs) {
void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
  for (const auto& item : inputs) {
    auto idx = engine_->getBindingIndex(item.name.c_str());
    std::vector<int> shape(item.shape.begin(), item.shape.end());
    auto dims = ToDims(shape);
    context_->setBindingDimensions(idx, dims);
    if (item.Nbytes() > inputs_buffer_[item.name].nbBytes()) {

    if (item.device == Device::GPU) {
      if (item.dtype == FDDataType::INT64) {
        // TODO(liqi): cast int64 to int32
        // TRT don't support INT64
        FDASSERT(false,
                 "TRT don't support INT64 input on GPU, "
                 "please use INT32 input");
      } else {
        // no copy
        inputs_buffer_[item.name].SetExternalData(dims, item.Data());
      }
    } else {
      // Allocate input buffer memory
      inputs_buffer_[item.name].resize(dims);

      // copy from cpu to gpu
      if (item.dtype == FDDataType::INT64) {
        int64_t* data = static_cast<int64_t*>(const_cast<void*>(item.Data()));
        std::vector<int32_t> casted_data(data, data + item.Numel());
        FDASSERT(cudaMemcpyAsync(inputs_buffer_[item.name].data(),
                                 static_cast<void*>(casted_data.data()),
                                 item.Nbytes() / 2, cudaMemcpyHostToDevice,
                                 stream_) == 0,
                 "Error occurs while copy memory from CPU to GPU.");
      } else {
        FDASSERT(cudaMemcpyAsync(inputs_buffer_[item.name].data(), item.Data(),
                                 item.Nbytes(), cudaMemcpyHostToDevice,
                                 stream_) == 0,
                 "Error occurs while copy memory from CPU to GPU.");
      }
    }
    // binding input buffer
    bindings_[idx] = inputs_buffer_[item.name].data();
  }
}

void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs) {
  if (outputs->size() != outputs_desc_.size()) {
    outputs->resize(outputs_desc_.size());
  }
@@ -365,15 +383,17 @@ void TrtBackend::AllocateBufferInDynamicShape(
             "Cannot find output: %s of tensorrt network from the original model.",
             outputs_desc_[i].name.c_str());
    auto ori_idx = iter->second;
    std::vector<int64_t> shape(output_dims.d, output_dims.d + output_dims.nbDims);
    (*outputs)[ori_idx].Allocate(shape, GetFDDataType(outputs_desc_[i].dtype), outputs_desc_[i].name);
    if ((*outputs)[ori_idx].Nbytes() >
        outputs_buffer_[outputs_desc_[i].name].nbBytes()) {
    // set user's outputs info
    std::vector<int64_t> shape(output_dims.d,
                               output_dims.d + output_dims.nbDims);
    (*outputs)[ori_idx].Resize(shape, GetFDDataType(outputs_desc_[i].dtype),
                               outputs_desc_[i].name);
    // Allocate output buffer memory
    outputs_buffer_[outputs_desc_[i].name].resize(output_dims);
    // binding output buffer
    bindings_[idx] = outputs_buffer_[outputs_desc_[i].name].data();
    }
  }
}

bool TrtBackend::BuildTrtEngine() {
  auto config =
@@ -14,6 +14,8 @@

#pragma once

#include <cuda_runtime_api.h>

#include <iostream>
#include <map>
#include <string>
@@ -23,7 +25,6 @@
#include "NvOnnxParser.h"
#include "fastdeploy/backends/backend.h"
#include "fastdeploy/backends/tensorrt/utils.h"
#include <cuda_runtime_api.h>

namespace fastdeploy {

@@ -109,12 +110,12 @@ class TrtBackend : public BaseBackend {
  std::map<std::string, ShapeRangeInfo> shape_range_info_;

  void GetInputOutputInfo();
  void AllocateBufferInDynamicShape(const std::vector<FDTensor>& inputs,
                                    std::vector<FDTensor>* outputs);
  bool CreateTrtEngineFromOnnx(const std::string& onnx_model_buffer);
  bool BuildTrtEngine();
  bool LoadTrtCache(const std::string& trt_engine_file);
  int ShapeRangeInfoUpdated(const std::vector<FDTensor>& inputs);
  void SetInputs(const std::vector<FDTensor>& inputs);
  void AllocateOutputsBuffer(std::vector<FDTensor>* outputs);
};

}  // namespace fastdeploy
@@ -14,11 +14,9 @@

#pragma once

#include "NvInfer.h"
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/utils/utils.h"
#include <algorithm>
#include <cuda_runtime_api.h>

#include <algorithm>
#include <iostream>
#include <map>
#include <memory>
@@ -26,17 +24,24 @@
#include <string>
#include <vector>

#include "NvInfer.h"
#include "fastdeploy/core/allocate.h"
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/utils/utils.h"

namespace fastdeploy {

struct FDInferDeleter {
  template <typename T> void operator()(T* obj) const {
  template <typename T>
  void operator()(T* obj) const {
    if (obj) {
      obj->destroy();
    }
  }
};

template <typename T> using FDUniquePtr = std::unique_ptr<T, FDInferDeleter>;
template <typename T>
using FDUniquePtr = std::unique_ptr<T, FDInferDeleter>;

int64_t Volume(const nvinfer1::Dims& d);

@@ -64,13 +69,18 @@ std::ostream& operator<<(std::ostream& out, const std::vector<T>& vec) {
  return out;
}

template <typename AllocFunc, typename FreeFunc> class FDGenericBuffer {
template <typename AllocFunc, typename FreeFunc>
class FDGenericBuffer {
 public:
  //!
  //! \brief Construct an empty buffer.
  //!
  explicit FDGenericBuffer(nvinfer1::DataType type = nvinfer1::DataType::kFLOAT)
      : mSize(0), mCapacity(0), mType(type), mBuffer(nullptr) {}
      : mSize(0),
        mCapacity(0),
        mType(type),
        mBuffer(nullptr),
        mExternal_buffer(nullptr) {}

  //!
  //! \brief Construct a buffer with the specified allocation size in bytes.
@@ -82,8 +92,18 @@ template <typename AllocFunc, typename FreeFunc> class FDGenericBuffer {
    }
  }

  //!
  //! \brief This use to skip memory copy step.
  //!
  FDGenericBuffer(size_t size, nvinfer1::DataType type, void* buffer)
      : mSize(size), mCapacity(size), mType(type) {
    mExternal_buffer = buffer;
  }

  FDGenericBuffer(FDGenericBuffer&& buf)
      : mSize(buf.mSize), mCapacity(buf.mCapacity), mType(buf.mType),
      : mSize(buf.mSize),
        mCapacity(buf.mCapacity),
        mType(buf.mType),
        mBuffer(buf.mBuffer) {
    buf.mSize = 0;
    buf.mCapacity = 0;
@@ -109,12 +129,18 @@ template <typename AllocFunc, typename FreeFunc> class FDGenericBuffer {
  //!
  //! \brief Returns pointer to underlying array.
  //!
  void* data() { return mBuffer; }
  void* data() {
    if (mExternal_buffer != nullptr) return mExternal_buffer;
    return mBuffer;
  }

  //!
  //! \brief Returns pointer to underlying array.
  //!
  const void* data() const { return mBuffer; }
  const void* data() const {
    if (mExternal_buffer != nullptr) return mExternal_buffer;
    return mBuffer;
  }

  //!
  //! \brief Returns the size (in number of elements) of the buffer.
@@ -126,11 +152,29 @@ template <typename AllocFunc, typename FreeFunc> class FDGenericBuffer {
  //!
  size_t nbBytes() const { return this->size() * TrtDataTypeSize(mType); }

  //!
  //! \brief Set user memory buffer for TRT Buffer
  //!
  void SetExternalData(size_t size, nvinfer1::DataType type, void* buffer) {
    mSize = mCapacity = size;
    mType = type;
    mExternal_buffer = const_cast<void*>(buffer);
  }

  //!
  //! \brief Set user memory buffer for TRT Buffer
  //!
  void SetExternalData(const nvinfer1::Dims& dims, const void* buffer) {
    mSize = mCapacity = Volume(dims);
    mExternal_buffer = const_cast<void*>(buffer);
  }

  //!
  //! \brief Resizes the buffer. This is a no-op if the new size is smaller than
  //! or equal to the current capacity.
  //!
  void resize(size_t newSize) {
    mExternal_buffer = nullptr;
    mSize = newSize;
    if (mCapacity < newSize) {
      freeFn(mBuffer);
@@ -146,28 +190,20 @@ template <typename AllocFunc, typename FreeFunc> class FDGenericBuffer {
  //!
  void resize(const nvinfer1::Dims& dims) { return this->resize(Volume(dims)); }

  ~FDGenericBuffer() { freeFn(mBuffer); }
  ~FDGenericBuffer() {
    mExternal_buffer = nullptr;
    freeFn(mBuffer);
  }

 private:
  size_t mSize{0}, mCapacity{0};
  nvinfer1::DataType mType;
  void* mBuffer;
  void* mExternal_buffer;
  AllocFunc allocFn;
  FreeFunc freeFn;
};

class FDDeviceAllocator {
 public:
  bool operator()(void** ptr, size_t size) const {
    return cudaMalloc(ptr, size) == cudaSuccess;
  }
};

class FDDeviceFree {
 public:
  void operator()(void* ptr) const { cudaFree(ptr); }
};

using FDDeviceBuffer = FDGenericBuffer<FDDeviceAllocator, FDDeviceFree>;

class FDTrtLogger : public nvinfer1::ILogger {
@@ -197,7 +233,7 @@ class FDTrtLogger : public nvinfer1::ILogger {
};

struct ShapeRangeInfo {
  ShapeRangeInfo(const std::vector<int64_t>& new_shape) {
  explicit ShapeRangeInfo(const std::vector<int64_t>& new_shape) {
    shape.assign(new_shape.begin(), new_shape.end());
    min.resize(new_shape.size());
    max.resize(new_shape.size());
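Editor's note: a minimal sketch (not from the commit) of how the extended FDGenericBuffer above is intended to be used inside the TRT backend, assuming a CUDA-enabled build; the dims and gpu_ptr arguments are placeholders supplied by the caller.

#include "fastdeploy/backends/tensorrt/utils.h"

void BufferSketch(const nvinfer1::Dims& dims, void* gpu_ptr) {
  // Owning buffer: resize() allocates device memory through FDDeviceAllocator.
  fastdeploy::FDDeviceBuffer owned;
  owned.resize(dims);
  void* dst = owned.data();  // device pointer usable as a TRT binding

  // Aliasing buffer: SetExternalData() records the user's GPU pointer so that
  // data() returns it directly and no allocation or copy happens.
  fastdeploy::FDDeviceBuffer aliased;
  aliased.SetExternalData(dims, gpu_ptr);
  void* bind = aliased.data();

  (void)dst;
  (void)bind;
}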
csrc/fastdeploy/core/allocate.cc (new file, 41 lines)
@@ -0,0 +1,41 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#ifdef WITH_GPU
#include <cuda_runtime_api.h>
#endif

#include "fastdeploy/core/allocate.h"

namespace fastdeploy {

bool FDHostAllocator::operator()(void** ptr, size_t size) const {
  *ptr = malloc(size);
  return *ptr != nullptr;
}

void FDHostFree::operator()(void* ptr) const { free(ptr); }

#ifdef WITH_GPU

bool FDDeviceAllocator::operator()(void** ptr, size_t size) const {
  return cudaMalloc(ptr, size) == cudaSuccess;
}

void FDDeviceFree::operator()(void* ptr) const { cudaFree(ptr); }

#endif

}  // namespace fastdeploy
csrc/fastdeploy/core/allocate.h (new file, 50 lines)
@@ -0,0 +1,50 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <memory>
#include <new>
#include <numeric>
#include <string>
#include <vector>

#include "fastdeploy/utils/utils.h"

namespace fastdeploy {

class FASTDEPLOY_DECL FDHostAllocator {
 public:
  bool operator()(void** ptr, size_t size) const;
};

class FASTDEPLOY_DECL FDHostFree {
 public:
  void operator()(void* ptr) const;
};

#ifdef WITH_GPU

class FASTDEPLOY_DECL FDDeviceAllocator {
 public:
  bool operator()(void** ptr, size_t size) const;
};

class FASTDEPLOY_DECL FDDeviceFree {
 public:
  void operator()(void* ptr) const;
};

#endif

}  // namespace fastdeploy
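Editor's note: a short usage sketch (an assumption, not code from the commit) of the allocator functors declared in allocate.h above. FDHostAllocator/FDHostFree wrap malloc/free, and the device pair is only compiled when FastDeploy is built with -DWITH_GPU=ON.

#include "fastdeploy/core/allocate.h"

void AllocatorSketch() {
  void* host_ptr = nullptr;
  if (fastdeploy::FDHostAllocator()(&host_ptr, 1024)) {
    // ... use 1 KiB of host memory ...
    fastdeploy::FDHostFree()(host_ptr);
  }
#ifdef WITH_GPU
  void* device_ptr = nullptr;
  if (fastdeploy::FDDeviceAllocator()(&device_ptr, 1024)) {
    // ... use 1 KiB of device memory allocated with cudaMalloc ...
    fastdeploy::FDDeviceFree()(device_ptr);
  }
#endif
}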
@@ -13,6 +13,7 @@
// limitations under the License.

#include "fastdeploy/core/fd_tensor.h"

#include "fastdeploy/utils/utils.h"

#ifdef WITH_GPU
@@ -25,55 +26,69 @@ void* FDTensor::MutableData() {
  if (external_data_ptr != nullptr) {
    return external_data_ptr;
  }
  return data.data();
  return buffer_;
}

void* FDTensor::Data() {
  if (external_data_ptr != nullptr) {
    if (device == Device::GPU) {
#ifdef WITH_GPU
      // need to copy cuda mem to cpu first
      temporary_cpu_buffer.resize(Nbytes());
      FDASSERT(cudaMemcpy(temporary_cpu_buffer.data(), external_data_ptr,
                          Nbytes(), cudaMemcpyDeviceToHost) == 0,
               "[ERROR] Error occurs while copy memory from GPU to CPU");
      return temporary_cpu_buffer.data();
#else
      FDASSERT(false,
               "The FastDeploy didn't compile under -DWITH_GPU=ON, so this is "
               "an unexpected problem happend.");
#endif
    } else {
      return external_data_ptr;
    }
  }
  return data.data();
  return buffer_;
}

const void* FDTensor::Data() const {
  if (external_data_ptr != nullptr) {
    return external_data_ptr;
  }
  return data.data();
  return buffer_;
}

const void* FDTensor::CpuData() const {
  if (device == Device::GPU) {
#ifdef WITH_GPU
    auto* cpu_ptr = const_cast<std::vector<int8_t>*>(&temporary_cpu_buffer);
    cpu_ptr->resize(Nbytes());
    // need to copy cuda mem to cpu first
    if (external_data_ptr != nullptr) {
      FDASSERT(cudaMemcpy(cpu_ptr->data(), external_data_ptr, Nbytes(),
                          cudaMemcpyDeviceToHost) == 0,
               "[ERROR] Error occurs while copy memory from GPU to CPU");

    } else {
      FDASSERT(cudaMemcpy(cpu_ptr->data(), buffer_, Nbytes(),
                          cudaMemcpyDeviceToHost) == 0,
               "[ERROR] Error occurs while buffer copy memory from GPU to CPU");
    }
    return cpu_ptr->data();
#else
    FDASSERT(false,
             "The FastDeploy didn't compile under -DWITH_GPU=ON, so this is "
             "an unexpected problem happend.");
#endif
  }
  return Data();
}

void FDTensor::SetExternalData(const std::vector<int64_t>& new_shape,
                               const FDDataType& data_type, void* data_buffer) {
                               const FDDataType& data_type, void* data_buffer,
                               const Device& new_device) {
  dtype = data_type;
  shape.assign(new_shape.begin(), new_shape.end());
  external_data_ptr = data_buffer;
  device = new_device;
}

void FDTensor::Allocate(const std::vector<int64_t>& new_shape,
                        const FDDataType& data_type,
                        const std::string& tensor_name) {
                        const std::string& tensor_name,
                        const Device& new_device) {
  dtype = data_type;
  name = tensor_name;
  shape.assign(new_shape.begin(), new_shape.end());
  int unit = FDDataTypeSize(data_type);
  int total_size =
      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
  data.resize(total_size * unit);
  device = new_device;
  size_t nbytes = Nbytes();
  FDASSERT(AllocFn(nbytes),
           "The FastDeploy FDTensor allocate cpu memory error");
}

int FDTensor::Nbytes() const { return Numel() * FDDataTypeSize(dtype); }
@@ -82,6 +97,44 @@ int FDTensor::Numel() const {
  return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
}

void FDTensor::Resize(size_t new_nbytes) {
  size_t nbytes = Nbytes();
  if (new_nbytes > nbytes) {
    FreeFn();
    AllocFn(new_nbytes);
  }
}

void FDTensor::Resize(const std::vector<int64_t>& new_shape) {
  int numel = Numel();
  int new_numel = std::accumulate(new_shape.begin(), new_shape.end(), 1,
                                  std::multiplies<int>());
  shape.assign(new_shape.begin(), new_shape.end());
  if (new_numel > numel) {
    FreeFn();
    size_t nbytes = new_numel * FDDataTypeSize(dtype);
    AllocFn(nbytes);
  }
}

void FDTensor::Resize(const std::vector<int64_t>& new_shape,
                      const FDDataType& data_type,
                      const std::string& tensor_name,
                      const Device& new_device) {
  name = tensor_name;
  device = new_device;
  size_t nbytes = Nbytes();
  shape.assign(new_shape.begin(), new_shape.end());
  dtype = data_type;
  int new_nbytes = std::accumulate(new_shape.begin(), new_shape.end(), 1,
                                   std::multiplies<int>()) *
                   FDDataTypeSize(data_type);
  if (new_nbytes > nbytes) {
    FreeFn();
    AllocFn(new_nbytes);
  }
}

template <typename T>
void CalculateStatisInfo(void* src_ptr, int size, double* mean, double* max,
                         double* min) {
@@ -18,15 +18,17 @@
#include <string>
#include <vector>

#include "fastdeploy/core/allocate.h"
#include "fastdeploy/core/fd_type.h"

namespace fastdeploy {

struct FASTDEPLOY_DECL FDTensor {
  std::vector<int8_t> data;
  std::vector<int64_t> shape;
  // std::vector<int8_t> data;
  void* buffer_ = nullptr;
  std::vector<int64_t> shape = {0};
  std::string name = "";
  FDDataType dtype;
  FDDataType dtype = FDDataType::INT8;

  // This use to skip memory copy step
  // the external_data_ptr will point to the user allocated memory
@@ -46,28 +48,32 @@ struct FASTDEPLOY_DECL FDTensor {
  // Get data buffer pointer
  void* MutableData();

  // Use this data to get the tensor data to process
  // Since the most senario is process data in CPU
  // this function weill return a pointer to cpu memory
  // buffer.
  // If the original data is on other device, the data
  // will copy to cpu store in `temporary_cpu_buffer`
  void* Data();

  const void* Data() const;

  // Use this data to get the tensor data to process
  // Since the most senario is process data in CPU
  // this function will return a pointer to cpu memory
  // buffer.
  // If the original data is on other device, the data
  // will copy to cpu store in `temporary_cpu_buffer`
  const void* CpuData() const;

  // Set user memory buffer for Tensor, the memory is managed by
  // the user it self, but the Tensor will share the memory with user
  // So take care with the user buffer
  void SetExternalData(const std::vector<int64_t>& new_shape,
                       const FDDataType& data_type, void* data_buffer);
                       const FDDataType& data_type, void* data_buffer,
                       const Device& new_device = Device::CPU);

  // Initialize Tensor
  // Include setting attribute for tensor
  // and allocate cpu memory buffer
  void Allocate(const std::vector<int64_t>& new_shape,
                const FDDataType& data_type,
                const std::string& tensor_name = "");
                const std::string& tensor_name = "",
                const Device& new_device = Device::CPU);

  // Total size of tensor memory buffer in bytes
  int Nbytes() const;
@@ -75,13 +81,51 @@ struct FASTDEPLOY_DECL FDTensor {
  // Total number of elements in this tensor
  int Numel() const;

  void Resize(size_t nbytes);

  void Resize(const std::vector<int64_t>& new_shape);

  void Resize(const std::vector<int64_t>& new_shape,
              const FDDataType& data_type, const std::string& tensor_name = "",
              const Device& new_device = Device::CPU);

  // Debug function
  // Use this function to print shape, dtype, mean, max, min
  // prefix will also be printed as tag
  void PrintInfo(const std::string& prefix = "TensorInfo: ");

  bool AllocFn(size_t nbytes) {
    if (device == Device::GPU) {
#ifdef WITH_GPU
      return FDDeviceAllocator()(&buffer_, nbytes);
#else
      FDASSERT(false,
               "The FastDeploy FDTensor allocator didn't compile under "
               "-DWITH_GPU=ON,"
               "so this is an unexpected problem happend.");
#endif
    }
    return FDHostAllocator()(&buffer_, nbytes);
  }

  void FreeFn() {
    if (external_data_ptr != nullptr) external_data_ptr = nullptr;
    if (buffer_ != nullptr) {
      if (device == Device::GPU) {
#ifdef WITH_GPU
        FDDeviceFree()(buffer_);
#endif
      } else {
        FDHostFree()(buffer_);
      }
      buffer_ = nullptr;
    }
  }

  FDTensor() {}
  explicit FDTensor(const std::string& tensor_name);

  ~FDTensor() { FreeFn(); }
};

}  // namespace fastdeploy
@@ -13,6 +13,7 @@
// limitations under the License.

#include "fastdeploy/core/fd_type.h"

#include "fastdeploy/utils/utils.h"

namespace fastdeploy {
@@ -33,6 +34,8 @@ int FDDataTypeSize(const FDDataType& data_type) {
    return sizeof(double);
  } else if (data_type == FDDataType::UINT8) {
    return sizeof(uint8_t);
  } else if (data_type == FDDataType::INT8) {
    return sizeof(int8_t);
  } else {
    FDASSERT(false, "Unexpected data type: %s", Str(data_type).c_str());
  }
@@ -42,9 +45,6 @@ int FDDataTypeSize(const FDDataType& data_type) {
std::string Str(const Device& d) {
  std::string out;
  switch (d) {
    case Device::DEFAULT:
      out = "Device::DEFAULT";
      break;
    case Device::CPU:
      out = "Device::CPU";
      break;
@@ -22,7 +22,7 @@

namespace fastdeploy {

enum FASTDEPLOY_DECL Device { DEFAULT, CPU, GPU };
enum FASTDEPLOY_DECL Device { CPU, GPU };

FASTDEPLOY_DECL std::string Str(const Device& d);

@@ -71,14 +71,13 @@ void BindRuntime(pybind11::module& m) {
        std::vector<FDTensor> inputs(data.size());
        int index = 0;
        for (auto iter = data.begin(); iter != data.end(); ++iter) {
          inputs[index].dtype =
              NumpyDataTypeToFDDataType(iter->second.dtype());
          inputs[index].shape.insert(
              inputs[index].shape.begin(), iter->second.shape(),
          std::vector<int64_t> data_shape;
          data_shape.insert(data_shape.begin(), iter->second.shape(),
                            iter->second.shape() + iter->second.ndim());
          auto dtype = NumpyDataTypeToFDDataType(iter->second.dtype());
          // TODO(jiangjiajun) Maybe skip memory copy is a better choice
          // use SetExternalData
          inputs[index].data.resize(iter->second.nbytes());
          inputs[index].Resize(data_shape, dtype);
          memcpy(inputs[index].MutableData(), iter->second.mutable_data(),
                 iter->second.nbytes());
          inputs[index].name = iter->first;
@@ -59,13 +59,15 @@ FDDataType NumpyDataTypeToFDDataType(const pybind11::dtype& np_dtype) {

void PyArrayToTensor(pybind11::array& pyarray, FDTensor* tensor,
                     bool share_buffer) {
  tensor->dtype = NumpyDataTypeToFDDataType(pyarray.dtype());
  tensor->shape.insert(tensor->shape.begin(), pyarray.shape(),
  auto dtype = NumpyDataTypeToFDDataType(pyarray.dtype());
  std::vector<int64_t> data_shape;
  data_shape.insert(data_shape.begin(), pyarray.shape(),
                    pyarray.shape() + pyarray.ndim());
  if (share_buffer) {
    tensor->external_data_ptr = pyarray.mutable_data();
    tensor-> SetExternalData(data_shape, dtype,
                             pyarray.mutable_data());
  } else {
    tensor->data.resize(pyarray.nbytes());
    tensor->Resize(data_shape, dtype);
    memcpy(tensor->MutableData(), pyarray.mutable_data(), pyarray.nbytes());
  }
}
@@ -17,6 +17,7 @@
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <type_traits>

#include "fastdeploy/fastdeploy_runtime.h"
@@ -42,7 +43,8 @@ pybind11::array TensorToPyArray(const FDTensor& tensor);
cv::Mat PyArrayToCvMat(pybind11::array& pyarray);
#endif

template <typename T> FDDataType CTypeToFDDataType() {
template <typename T>
FDDataType CTypeToFDDataType() {
  if (std::is_same<T, int32_t>::value) {
    return FDDataType::INT32;
  } else if (std::is_same<T, int64_t>::value) {
@@ -58,16 +60,17 @@ template <typename T> FDDataType CTypeToFDDataType() {
}

template <typename T>
std::vector<pybind11::array>
PyBackendInfer(T& self, const std::vector<std::string>& names,
std::vector<pybind11::array> PyBackendInfer(
    T& self, const std::vector<std::string>& names,
    std::vector<pybind11::array>& data) {
  std::vector<FDTensor> inputs(data.size());
  for (size_t i = 0; i < data.size(); ++i) {
    // TODO(jiangjiajun) here is considered to use user memory directly
    inputs[i].dtype = NumpyDataTypeToFDDataType(data[i].dtype());
    inputs[i].shape.insert(inputs[i].shape.begin(), data[i].shape(),
    auto dtype = NumpyDataTypeToFDDataType(data[i].dtype());
    std::vector<int64_t> data_shape;
    data_shape.insert(data_shape.begin(), data[i].shape(),
                      data[i].shape() + data[i].ndim());
    inputs[i].data.resize(data[i].nbytes());
    inputs[i].Resize(data_shape, dtype);
    memcpy(inputs[i].MutableData(), data[i].mutable_data(), data[i].nbytes());
    inputs[i].name = names[i];
  }