mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 17:17:14 +08:00
[Serving][Backend] Backend support zero_copy_infer and Serving reduce the output memory copy (#703)
* backend add zero copy infer interface * fix bug * fix bug * fix bug * paddle ipu
This commit is contained in:
@@ -62,8 +62,11 @@ class BaseBackend {
|
||||
virtual TensorInfo GetOutputInfo(int index) = 0;
|
||||
virtual std::vector<TensorInfo> GetInputInfos() = 0;
|
||||
virtual std::vector<TensorInfo> GetOutputInfos() = 0;
|
||||
// if copy_to_fd is true, copy memory data to FDTensor
|
||||
// else share memory to FDTensor(only Paddle、ORT、TRT、OpenVINO support it)
|
||||
virtual bool Infer(std::vector<FDTensor>& inputs,
|
||||
std::vector<FDTensor>* outputs) = 0;
|
||||
std::vector<FDTensor>* outputs,
|
||||
bool copy_to_fd = true) = 0;
|
||||
virtual std::unique_ptr<BaseBackend> Clone(void *stream = nullptr,
|
||||
int device_id = -1) {
|
||||
FDERROR << "Clone no support" << std::endl;
|
||||
|
@@ -187,7 +187,8 @@ TensorInfo LiteBackend::GetOutputInfo(int index) {
|
||||
std::vector<TensorInfo> LiteBackend::GetOutputInfos() { return outputs_desc_; }
|
||||
|
||||
bool LiteBackend::Infer(std::vector<FDTensor>& inputs,
|
||||
std::vector<FDTensor>* outputs) {
|
||||
std::vector<FDTensor>* outputs,
|
||||
bool copy_to_fd) {
|
||||
if (inputs.size() != inputs_desc_.size()) {
|
||||
FDERROR << "[LiteBackend] Size of inputs(" << inputs.size()
|
||||
<< ") should keep same with the inputs of this model("
|
||||
|
@@ -60,7 +60,9 @@ class LiteBackend : public BaseBackend {
|
||||
const std::string& params_file,
|
||||
const LiteBackendOption& option = LiteBackendOption());
|
||||
|
||||
bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs) override; // NOLINT
|
||||
bool Infer(std::vector<FDTensor>& inputs,
|
||||
std::vector<FDTensor>* outputs,
|
||||
bool copy_to_fd = true) override; // NOLINT
|
||||
|
||||
int NumInputs() const override { return inputs_desc_.size(); }
|
||||
|
||||
|
@@ -341,7 +341,8 @@ int OpenVINOBackend::NumInputs() const { return input_infos_.size(); }
|
||||
int OpenVINOBackend::NumOutputs() const { return output_infos_.size(); }
|
||||
|
||||
bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs,
|
||||
std::vector<FDTensor>* outputs) {
|
||||
std::vector<FDTensor>* outputs,
|
||||
bool copy_to_fd) {
|
||||
if (inputs.size() != input_infos_.size()) {
|
||||
FDERROR << "[OpenVINOBackend] Size of the inputs(" << inputs.size()
|
||||
<< ") should keep same with the inputs of this model("
|
||||
@@ -364,11 +365,20 @@ bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs,
|
||||
auto out_tensor_shape = out_tensor.get_shape();
|
||||
std::vector<int64_t> shape(out_tensor_shape.begin(),
|
||||
out_tensor_shape.end());
|
||||
(*outputs)[i].Allocate(shape,
|
||||
if(copy_to_fd) {
|
||||
(*outputs)[i].Resize(shape,
|
||||
OpenVINODataTypeToFD(out_tensor.get_element_type()),
|
||||
output_infos_[i].name);
|
||||
output_infos_[i].name,
|
||||
Device::CPU);
|
||||
memcpy((*outputs)[i].MutableData(), out_tensor.data(),
|
||||
(*outputs)[i].Nbytes());
|
||||
} else {
|
||||
(*outputs)[i].name = output_infos_[i].name;
|
||||
(*outputs)[i].SetExternalData(shape,
|
||||
OpenVINODataTypeToFD(out_tensor.get_element_type()),
|
||||
out_tensor.data(),
|
||||
Device::CPU);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@@ -48,7 +48,8 @@ class OpenVINOBackend : public BaseBackend {
|
||||
const OpenVINOBackendOption& option = OpenVINOBackendOption());
|
||||
|
||||
bool Infer(std::vector<FDTensor>& inputs,
|
||||
std::vector<FDTensor>* outputs) override;
|
||||
std::vector<FDTensor>* outputs,
|
||||
bool copy_to_fd = true) override;
|
||||
|
||||
int NumInputs() const override;
|
||||
|
||||
|
@@ -181,8 +181,8 @@ bool OrtBackend::InitFromOnnx(const std::string& model_file,
|
||||
return true;
|
||||
}
|
||||
|
||||
void OrtBackend::CopyToCpu(const Ort::Value& value, FDTensor* tensor,
|
||||
const std::string& name) {
|
||||
void OrtBackend::OrtValueToFDTensor(const Ort::Value& value, FDTensor* tensor,
|
||||
const std::string& name, bool copy_to_fd) {
|
||||
const auto info = value.GetTensorTypeAndShapeInfo();
|
||||
const auto data_type = info.GetElementType();
|
||||
size_t numel = info.GetElementCount();
|
||||
@@ -210,12 +210,21 @@ void OrtBackend::CopyToCpu(const Ort::Value& value, FDTensor* tensor,
|
||||
"Unrecognized data type of %d while calling OrtBackend::CopyToCpu().",
|
||||
data_type);
|
||||
}
|
||||
const void* value_ptr = value.GetTensorData<void*>();
|
||||
if (copy_to_fd) {
|
||||
tensor->Resize(shape, dtype, name);
|
||||
memcpy(tensor->MutableData(), value.GetTensorData<void*>(), numel);
|
||||
memcpy(tensor->MutableData(), value_ptr, numel);
|
||||
} else {
|
||||
tensor->name = name;
|
||||
tensor->SetExternalData(
|
||||
shape, dtype,
|
||||
const_cast<void*>(value_ptr), Device::CPU);
|
||||
}
|
||||
}
|
||||
|
||||
bool OrtBackend::Infer(std::vector<FDTensor>& inputs,
|
||||
std::vector<FDTensor>* outputs) {
|
||||
std::vector<FDTensor>* outputs,
|
||||
bool copy_to_fd) {
|
||||
if (inputs.size() != inputs_desc_.size()) {
|
||||
FDERROR << "[OrtBackend] Size of the inputs(" << inputs.size()
|
||||
<< ") should keep same with the inputs of this model("
|
||||
@@ -243,11 +252,12 @@ bool OrtBackend::Infer(std::vector<FDTensor>& inputs,
|
||||
return false;
|
||||
}
|
||||
|
||||
// Copy result after inference
|
||||
// Convert result after inference
|
||||
std::vector<Ort::Value> ort_outputs = binding_->GetOutputValues();
|
||||
outputs->resize(ort_outputs.size());
|
||||
for (size_t i = 0; i < ort_outputs.size(); ++i) {
|
||||
CopyToCpu(ort_outputs[i], &((*outputs)[i]), outputs_desc_[i].name);
|
||||
OrtValueToFDTensor(ort_outputs[i], &((*outputs)[i]),
|
||||
outputs_desc_[i].name, copy_to_fd);
|
||||
}
|
||||
|
||||
return true;
|
||||
|
@@ -68,7 +68,8 @@ class OrtBackend : public BaseBackend {
|
||||
bool from_memory_buffer = false);
|
||||
|
||||
bool Infer(std::vector<FDTensor>& inputs,
|
||||
std::vector<FDTensor>* outputs) override;
|
||||
std::vector<FDTensor>* outputs,
|
||||
bool copy_to_fd = true) override;
|
||||
|
||||
int NumInputs() const override { return inputs_desc_.size(); }
|
||||
|
||||
@@ -92,7 +93,7 @@ class OrtBackend : public BaseBackend {
|
||||
Ort::CustomOpDomain custom_op_domain_ = Ort::CustomOpDomain("Paddle");
|
||||
#endif
|
||||
OrtBackendOption option_;
|
||||
void CopyToCpu(const Ort::Value& value, FDTensor* tensor,
|
||||
const std::string& name);
|
||||
void OrtValueToFDTensor(const Ort::Value& value, FDTensor* tensor,
|
||||
const std::string& name, bool copy_to_fd);
|
||||
};
|
||||
} // namespace fastdeploy
|
||||
|
@@ -194,7 +194,8 @@ std::vector<TensorInfo> PaddleBackend::GetOutputInfos() {
|
||||
}
|
||||
|
||||
bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
|
||||
std::vector<FDTensor>* outputs) {
|
||||
std::vector<FDTensor>* outputs,
|
||||
bool copy_to_fd) {
|
||||
if (inputs.size() != inputs_desc_.size()) {
|
||||
FDERROR << "[PaddleBackend] Size of inputs(" << inputs.size()
|
||||
<< ") should keep same with the inputs of this model("
|
||||
@@ -208,11 +209,18 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
|
||||
}
|
||||
|
||||
predictor_->Run();
|
||||
|
||||
// output share backend memory only support CPU or GPU
|
||||
if(option_.use_ipu) {
|
||||
copy_to_fd = true;
|
||||
}
|
||||
outputs->resize(outputs_desc_.size());
|
||||
for (size_t i = 0; i < outputs_desc_.size(); ++i) {
|
||||
auto handle = predictor_->GetOutputHandle(outputs_desc_[i].name);
|
||||
if(copy_to_fd) {
|
||||
(*outputs)[i].is_pinned_memory = option_.enable_pinned_memory;
|
||||
CopyTensorToCpu(handle, &((*outputs)[i]));
|
||||
}
|
||||
PaddleTensorToFDTensor(handle, &((*outputs)[i]), copy_to_fd);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@@ -87,9 +87,12 @@ paddle_infer::PlaceType ConvertFDDeviceToPlace(Device device);
|
||||
// Share memory buffer with paddle_infer::Tensor from fastdeploy::FDTensor
|
||||
void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor, FDTensor& fd_tensor);
|
||||
|
||||
// Copy memory data from paddle_infer::Tensor to fastdeploy::FDTensor
|
||||
void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
|
||||
FDTensor* fd_tensor);
|
||||
// convert paddle_infer::Tensor to fastdeploy::FDTensor
|
||||
// if copy_to_fd is true, copy memory data to FDTensor
|
||||
/// else share memory to FDTensor
|
||||
void PaddleTensorToFDTensor(std::unique_ptr<paddle_infer::Tensor>& tensor,
|
||||
FDTensor* fd_tensor,
|
||||
bool copy_to_fd);
|
||||
|
||||
// Convert data type from paddle inference to fastdeploy
|
||||
FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype);
|
||||
@@ -108,7 +111,9 @@ class PaddleBackend : public BaseBackend {
|
||||
const PaddleBackendOption& option = PaddleBackendOption());
|
||||
|
||||
bool Infer(std::vector<FDTensor>& inputs,
|
||||
std::vector<FDTensor>* outputs) override;
|
||||
std::vector<FDTensor>* outputs,
|
||||
bool copy_to_fd = true) override;
|
||||
|
||||
|
||||
int NumInputs() const override { return inputs_desc_.size(); }
|
||||
|
||||
|
@@ -61,12 +61,14 @@ void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor,
|
||||
Str(fd_tensor.dtype).c_str());
|
||||
}
|
||||
|
||||
void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
|
||||
FDTensor* fd_tensor) {
|
||||
void PaddleTensorToFDTensor(std::unique_ptr<paddle_infer::Tensor>& tensor,
|
||||
FDTensor* fd_tensor,
|
||||
bool copy_to_fd) {
|
||||
auto fd_dtype = PaddleDataTypeToFD(tensor->type());
|
||||
std::vector<int64_t> shape;
|
||||
auto tmp_shape = tensor->shape();
|
||||
shape.assign(tmp_shape.begin(), tmp_shape.end());
|
||||
if(copy_to_fd) {
|
||||
fd_tensor->Resize(shape, fd_dtype, tensor->name());
|
||||
if (fd_tensor->dtype == FDDataType::FP32) {
|
||||
tensor->CopyToCpu(static_cast<float*>(fd_tensor->MutableData()));
|
||||
@@ -80,6 +82,20 @@ void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
|
||||
}
|
||||
FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
|
||||
Str(fd_tensor->dtype).c_str());
|
||||
} else {
|
||||
paddle_infer::PlaceType place;
|
||||
int size = 0;
|
||||
// TODO(liqi): The tensor->data interface of paddle don't return device id
|
||||
// and don't support return void*.
|
||||
auto* out_data = tensor->data<uint8_t>(&place, &size);
|
||||
Device device = Device::CPU;
|
||||
if(place == paddle_infer::PlaceType::kGPU) {
|
||||
device = Device::GPU;
|
||||
}
|
||||
fd_tensor->SetExternalData(
|
||||
shape, fd_dtype,
|
||||
reinterpret_cast<void*>(out_data), device);
|
||||
}
|
||||
}
|
||||
|
||||
FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype) {
|
||||
|
@@ -188,7 +188,8 @@ bool PorosBackend::InitFromPoros(const std::string& model_file,
|
||||
}
|
||||
|
||||
bool PorosBackend::Infer(std::vector<FDTensor>& inputs,
|
||||
std::vector<FDTensor>* outputs) {
|
||||
std::vector<FDTensor>* outputs,
|
||||
bool copy_to_fd) {
|
||||
// Convert FD Tensor to PyTorch Tensor
|
||||
std::vector<torch::jit::IValue> poros_inputs;
|
||||
bool is_backend_cuda =
|
||||
|
@@ -85,7 +85,9 @@ class PorosBackend : public BaseBackend {
|
||||
std::vector<std::vector<FDTensor>>& prewarm_tensors,
|
||||
const PorosBackendOption& option = PorosBackendOption());
|
||||
|
||||
bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs);
|
||||
bool Infer(std::vector<FDTensor>& inputs,
|
||||
std::vector<FDTensor>* outputs,
|
||||
bool copy_to_fd = true) override;
|
||||
|
||||
int NumInputs() const { return _numinputs; }
|
||||
|
||||
|
@@ -289,7 +289,8 @@ std::vector<TensorInfo> RKNPU2Backend::GetOutputInfos() {
|
||||
}
|
||||
|
||||
bool RKNPU2Backend::Infer(std::vector<FDTensor>& inputs,
|
||||
std::vector<FDTensor>* outputs) {
|
||||
std::vector<FDTensor>* outputs,
|
||||
bool copy_to_fd) {
|
||||
int ret = RKNN_SUCC;
|
||||
// Judge whether the input and output size are the same
|
||||
if (inputs.size() != inputs_desc_.size()) {
|
||||
|
@@ -72,7 +72,8 @@ class RKNPU2Backend : public BaseBackend {
|
||||
std::vector<TensorInfo> GetInputInfos() override;
|
||||
std::vector<TensorInfo> GetOutputInfos() override;
|
||||
bool Infer(std::vector<FDTensor>& inputs,
|
||||
std::vector<FDTensor>* outputs) override;
|
||||
std::vector<FDTensor>* outputs,
|
||||
bool copy_to_fd = true) override;
|
||||
|
||||
private:
|
||||
// The object of rknn context.
|
||||
|
@@ -283,7 +283,8 @@ int TrtBackend::ShapeRangeInfoUpdated(const std::vector<FDTensor>& inputs) {
|
||||
}
|
||||
|
||||
bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
|
||||
std::vector<FDTensor>* outputs) {
|
||||
std::vector<FDTensor>* outputs,
|
||||
bool copy_to_fd) {
|
||||
if (inputs.size() != NumInputs()) {
|
||||
FDERROR << "Require " << NumInputs() << "inputs, but get " << inputs.size()
|
||||
<< "." << std::endl;
|
||||
@@ -304,7 +305,7 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
|
||||
|
||||
cudaSetDevice(option_.gpu_id);
|
||||
SetInputs(inputs);
|
||||
AllocateOutputsBuffer(outputs);
|
||||
AllocateOutputsBuffer(outputs, copy_to_fd);
|
||||
|
||||
if (!context_->enqueueV2(bindings_.data(), stream_, nullptr)) {
|
||||
FDERROR << "Failed to Infer with TensorRT." << std::endl;
|
||||
@@ -323,6 +324,11 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
|
||||
casted_output_tensors_[(*outputs)[i].name].Resize((*outputs)[i].shape, (*outputs)[i].dtype,
|
||||
(*outputs)[i].name, Device::GPU);
|
||||
function::CudaCast(output_tensor, &casted_output_tensors_[(*outputs)[i].name], stream_);
|
||||
if(!copy_to_fd) {
|
||||
(*outputs)[i].SetExternalData((*outputs)[i].shape, model_output_dtype,
|
||||
casted_output_tensors_[(*outputs)[i].name].MutableData(),
|
||||
Device::GPU, option_.gpu_id);
|
||||
}
|
||||
} else {
|
||||
casted_output_tensors_[(*outputs)[i].name].SetExternalData(
|
||||
(*outputs)[i].shape, model_output_dtype,
|
||||
@@ -330,6 +336,7 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
|
||||
Device::GPU);
|
||||
}
|
||||
}
|
||||
if (copy_to_fd) {
|
||||
for (size_t i = 0; i < outputs->size(); ++i) {
|
||||
FDASSERT(cudaMemcpyAsync((*outputs)[i].Data(),
|
||||
casted_output_tensors_[(*outputs)[i].name].Data(),
|
||||
@@ -339,6 +346,7 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
|
||||
}
|
||||
FDASSERT(cudaStreamSynchronize(stream_) == cudaSuccess,
|
||||
"[ERROR] Error occurs while sync cuda stream.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
@@ -427,7 +435,8 @@ void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
|
||||
}
|
||||
}
|
||||
|
||||
void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs) {
|
||||
void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs,
|
||||
bool copy_to_fd) {
|
||||
if (outputs->size() != outputs_desc_.size()) {
|
||||
outputs->resize(outputs_desc_.size());
|
||||
}
|
||||
@@ -446,18 +455,26 @@ void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs) {
|
||||
outputs_desc_[i].name.c_str());
|
||||
auto ori_idx = iter->second;
|
||||
|
||||
// set user's outputs info
|
||||
std::vector<int64_t> shape(output_dims.d,
|
||||
output_dims.d + output_dims.nbDims);
|
||||
(*outputs)[ori_idx].is_pinned_memory = option_.enable_pinned_memory;
|
||||
(*outputs)[ori_idx].Resize(shape, outputs_desc_[i].original_dtype,
|
||||
outputs_desc_[i].name);
|
||||
|
||||
// Allocate output buffer memory
|
||||
outputs_device_buffer_[outputs_desc_[i].name].resize(output_dims);
|
||||
|
||||
// binding output buffer
|
||||
bindings_[idx] = outputs_device_buffer_[outputs_desc_[i].name].data();
|
||||
|
||||
// set user's outputs info
|
||||
std::vector<int64_t> shape(output_dims.d,
|
||||
output_dims.d + output_dims.nbDims);
|
||||
if(copy_to_fd) {
|
||||
(*outputs)[ori_idx].is_pinned_memory = option_.enable_pinned_memory;
|
||||
(*outputs)[ori_idx].Resize(shape, outputs_desc_[i].original_dtype,
|
||||
outputs_desc_[i].name);
|
||||
} else {
|
||||
(*outputs)[ori_idx].name = outputs_desc_[i].name;
|
||||
(*outputs)[ori_idx].SetExternalData(
|
||||
shape, outputs_desc_[i].original_dtype,
|
||||
bindings_[idx], Device::GPU,
|
||||
option_.gpu_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@@ -97,7 +97,9 @@ class TrtBackend : public BaseBackend {
|
||||
bool InitFromOnnx(const std::string& model_file,
|
||||
const TrtBackendOption& option = TrtBackendOption(),
|
||||
bool from_memory_buffer = false);
|
||||
bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs);
|
||||
bool Infer(std::vector<FDTensor>& inputs,
|
||||
std::vector<FDTensor>* outputs,
|
||||
bool copy_to_fd = true) override;
|
||||
|
||||
int NumInputs() const { return inputs_desc_.size(); }
|
||||
int NumOutputs() const { return outputs_desc_.size(); }
|
||||
@@ -162,7 +164,8 @@ class TrtBackend : public BaseBackend {
|
||||
bool LoadTrtCache(const std::string& trt_engine_file);
|
||||
int ShapeRangeInfoUpdated(const std::vector<FDTensor>& inputs);
|
||||
void SetInputs(const std::vector<FDTensor>& inputs);
|
||||
void AllocateOutputsBuffer(std::vector<FDTensor>* outputs);
|
||||
void AllocateOutputsBuffer(std::vector<FDTensor>* outputs,
|
||||
bool copy_to_fd = true);
|
||||
};
|
||||
|
||||
} // namespace fastdeploy
|
||||
|
@@ -81,11 +81,13 @@ const void* FDTensor::CpuData() const {
|
||||
|
||||
void FDTensor::SetExternalData(const std::vector<int64_t>& new_shape,
|
||||
const FDDataType& data_type, void* data_buffer,
|
||||
const Device& new_device) {
|
||||
const Device& new_device,
|
||||
int new_device_id) {
|
||||
dtype = data_type;
|
||||
shape.assign(new_shape.begin(), new_shape.end());
|
||||
external_data_ptr = data_buffer;
|
||||
device = new_device;
|
||||
device_id = new_device_id;
|
||||
}
|
||||
|
||||
void FDTensor::ExpandDim(int64_t axis) {
|
||||
@@ -316,6 +318,8 @@ void FDTensor::FreeFn() {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(liqi): no src_device and dst_device
|
||||
// should support copy from cpu or gpu to cpu or gpu
|
||||
void FDTensor::CopyBuffer(void* dst, const void* src, size_t nbytes,
|
||||
const Device& device, bool is_pinned_memory) {
|
||||
if (device == Device::GPU) {
|
||||
@@ -383,7 +387,8 @@ FDTensor::FDTensor(const Scalar& scalar) {
|
||||
|
||||
FDTensor::FDTensor(const FDTensor& other)
|
||||
: shape(other.shape), name(other.name), dtype(other.dtype),
|
||||
device(other.device), external_data_ptr(other.external_data_ptr) {
|
||||
device(other.device), external_data_ptr(other.external_data_ptr),
|
||||
device_id(other.device_id) {
|
||||
// Copy buffer
|
||||
if (other.buffer_ == nullptr) {
|
||||
buffer_ = nullptr;
|
||||
@@ -398,7 +403,8 @@ FDTensor::FDTensor(const FDTensor& other)
|
||||
FDTensor::FDTensor(FDTensor&& other)
|
||||
: buffer_(other.buffer_), shape(std::move(other.shape)),
|
||||
name(std::move(other.name)), dtype(other.dtype),
|
||||
external_data_ptr(other.external_data_ptr), device(other.device) {
|
||||
external_data_ptr(other.external_data_ptr), device(other.device),
|
||||
device_id(other.device_id) {
|
||||
other.name = "";
|
||||
// Note(zhoushunjie): Avoid double free.
|
||||
other.buffer_ = nullptr;
|
||||
@@ -408,6 +414,7 @@ FDTensor::FDTensor(FDTensor&& other)
|
||||
FDTensor& FDTensor::operator=(const FDTensor& other) {
|
||||
if (&other != this) {
|
||||
// Copy buffer
|
||||
device_id = other.device_id;
|
||||
if (other.buffer_ == nullptr) {
|
||||
FreeFn();
|
||||
buffer_ = nullptr;
|
||||
@@ -435,6 +442,7 @@ FDTensor& FDTensor::operator=(FDTensor&& other) {
|
||||
name = std::move(other.name);
|
||||
dtype = other.dtype;
|
||||
device = other.device;
|
||||
device_id = other.device_id;
|
||||
|
||||
other.name = "";
|
||||
// Note(zhoushunjie): Avoid double free.
|
||||
|
@@ -78,7 +78,8 @@ struct FASTDEPLOY_DECL FDTensor {
|
||||
// So take care with the user buffer
|
||||
void SetExternalData(const std::vector<int64_t>& new_shape,
|
||||
const FDDataType& data_type, void* data_buffer,
|
||||
const Device& new_device = Device::CPU);
|
||||
const Device& new_device = Device::CPU,
|
||||
int new_device_id = -1);
|
||||
|
||||
// Expand the shape of a Tensor. Insert a new axis that will appear
|
||||
// at the `axis` position in the expanded Tensor shape.
|
||||
|
@@ -580,6 +580,39 @@ bool Runtime::Infer(std::vector<FDTensor>& input_tensors,
|
||||
return backend_->Infer(input_tensors, output_tensors);
|
||||
}
|
||||
|
||||
bool Runtime::Infer() {
|
||||
return backend_->Infer(input_tensors_, &output_tensors_, false);
|
||||
}
|
||||
|
||||
void Runtime::BindInputTensor(const std::string& name, FDTensor& input) {
|
||||
bool is_exist = false;
|
||||
for (auto& t : input_tensors_) {
|
||||
if (t.name == name) {
|
||||
is_exist = true;
|
||||
t.SetExternalData(input.shape, input.dtype,
|
||||
input.MutableData(), input.device,
|
||||
input.device_id);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if(!is_exist) {
|
||||
FDTensor new_tensor(name);
|
||||
new_tensor.SetExternalData(input.shape, input.dtype,
|
||||
input.MutableData(), input.device,
|
||||
input.device_id);
|
||||
input_tensors_.emplace_back(std::move(new_tensor));
|
||||
}
|
||||
}
|
||||
|
||||
FDTensor* Runtime::GetOutputTensor(const std::string& name) {
|
||||
for (auto& t : output_tensors_) {
|
||||
if (t.name == name) {
|
||||
return &t;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void Runtime::CreatePaddleBackend() {
|
||||
#ifdef ENABLE_PADDLE_BACKEND
|
||||
auto pd_option = PaddleBackendOption();
|
||||
|
14
fastdeploy/runtime.h
Executable file → Normal file
14
fastdeploy/runtime.h
Executable file → Normal file
@@ -405,6 +405,12 @@ struct FASTDEPLOY_DECL Runtime {
|
||||
bool Infer(std::vector<FDTensor>& input_tensors,
|
||||
std::vector<FDTensor>* output_tensors);
|
||||
|
||||
/** \brief No params inference the model.
|
||||
*
|
||||
* the input and output data need to pass through the BindInputTensor and GetOutputTensor interfaces.
|
||||
*/
|
||||
bool Infer();
|
||||
|
||||
/** \brief Compile TorchScript Module, only for Poros backend
|
||||
*
|
||||
* \param[in] prewarm_tensors Prewarm datas for compile
|
||||
@@ -432,6 +438,12 @@ struct FASTDEPLOY_DECL Runtime {
|
||||
/** \brief Get all the output information
|
||||
*/
|
||||
std::vector<TensorInfo> GetOutputInfos();
|
||||
/** \brief Bind FDTensor by name, no copy and share input memory
|
||||
*/
|
||||
void BindInputTensor(const std::string& name, FDTensor& input);
|
||||
/** \brief Get output FDTensor by name, no copy and share backend output memory
|
||||
*/
|
||||
FDTensor* GetOutputTensor(const std::string& name);
|
||||
|
||||
/** \brief Clone new Runtime when multiple instances of the same model are created
|
||||
*
|
||||
@@ -451,5 +463,7 @@ struct FASTDEPLOY_DECL Runtime {
|
||||
void CreateLiteBackend();
|
||||
void CreateRKNPU2Backend();
|
||||
std::unique_ptr<BaseBackend> backend_;
|
||||
std::vector<FDTensor> input_tensors_;
|
||||
std::vector<FDTensor> output_tensors_;
|
||||
};
|
||||
} // namespace fastdeploy
|
||||
|
@@ -607,9 +607,6 @@ class ModelInstanceState : public BackendModelInstance {
|
||||
std::vector<std::string> output_names_;
|
||||
std::vector<fastdeploy::TensorInfo> input_tensor_infos_;
|
||||
std::vector<fastdeploy::TensorInfo> output_tensor_infos_;
|
||||
|
||||
std::vector<fastdeploy::FDTensor> input_tensors_;
|
||||
std::vector<fastdeploy::FDTensor> output_tensors_;
|
||||
};
|
||||
|
||||
TRITONSERVER_Error* ModelInstanceState::Create(
|
||||
@@ -647,8 +644,6 @@ ModelInstanceState::~ModelInstanceState() { ReleaseRunResources(); }
|
||||
void ModelInstanceState::ReleaseRunResources() {
|
||||
input_names_.clear();
|
||||
output_names_.clear();
|
||||
input_tensors_.clear();
|
||||
output_tensors_.clear();
|
||||
input_tensor_infos_.clear();
|
||||
output_tensor_infos_.clear();
|
||||
}
|
||||
@@ -671,9 +666,7 @@ TRITONSERVER_Error* ModelInstanceState::ValidateInputs() {
|
||||
input_tensor_infos_ = runtime_->GetInputInfos();
|
||||
std::vector<std::string> names;
|
||||
GetInfoNames(input_tensor_infos_, names);
|
||||
input_tensors_.clear();
|
||||
input_names_.clear();
|
||||
input_tensors_.reserve(input_tensor_infos_.size());
|
||||
|
||||
triton::common::TritonJson::Value ios;
|
||||
RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("input", &ios));
|
||||
@@ -700,7 +693,6 @@ TRITONSERVER_Error* ModelInstanceState::ValidateInputs() {
|
||||
std::set<std::string> inames(names.begin(), names.end());
|
||||
RETURN_IF_ERROR(CheckAllowedModelInput(io, inames));
|
||||
}
|
||||
input_tensors_.emplace_back(io_name);
|
||||
|
||||
auto fd_data_type = ModelConfigDataTypeToFDType(io_dtype);
|
||||
if (fd_data_type == fastdeploy::FDDataType::UNKNOWN1) {
|
||||
@@ -759,11 +751,8 @@ TRITONSERVER_Error* ModelInstanceState::ValidateInputs() {
|
||||
|
||||
TRITONSERVER_Error* ModelInstanceState::ValidateOutputs() {
|
||||
output_tensor_infos_ = runtime_->GetOutputInfos();
|
||||
output_tensors_.clear();
|
||||
output_tensors_.reserve(output_tensor_infos_.size());
|
||||
std::set<std::string> out_names;
|
||||
for (const auto& info : output_tensor_infos_) {
|
||||
output_tensors_.emplace_back(info.name);
|
||||
out_names.insert(info.name);
|
||||
}
|
||||
output_names_.clear();
|
||||
@@ -793,7 +782,6 @@ TRITONSERVER_Error* ModelInstanceState::ValidateOutputs() {
|
||||
if (index < 0) {
|
||||
RETURN_IF_ERROR(CheckAllowedModelInput(io, out_names));
|
||||
}
|
||||
// output_tensors_.emplace_back(io_name);
|
||||
|
||||
auto fd_data_type = ModelConfigDataTypeToFDType(io_dtype);
|
||||
if (fd_data_type == fastdeploy::FDDataType::UNKNOWN1) {
|
||||
@@ -1009,7 +997,7 @@ void ModelInstanceState::ProcessRequests(TRITONBACKEND_Request** requests,
|
||||
TRITONSERVER_Error* ModelInstanceState::Run(
|
||||
std::vector<TRITONBACKEND_Response*>* responses,
|
||||
const uint32_t response_count) {
|
||||
runtime_->Infer(input_tensors_, &output_tensors_);
|
||||
runtime_->Infer();
|
||||
#ifdef TRITON_ENABLE_GPU
|
||||
if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
|
||||
cudaStreamSynchronize(CudaStream());
|
||||
@@ -1042,18 +1030,7 @@ TRITONSERVER_Error* ModelInstanceState::SetInputTensors(
|
||||
input, &input_name, &input_datatype, &input_shape, &input_dims_count,
|
||||
nullptr, nullptr));
|
||||
|
||||
int index = GetInfoIndex(std::string(input_name), input_tensor_infos_);
|
||||
if (index < 0) {
|
||||
auto err = TRITONSERVER_ErrorNew(
|
||||
TRITONSERVER_ERROR_INTERNAL,
|
||||
(std::string("Input name [") + input_name +
|
||||
std::string("] is not one of the FD predictor input: ") +
|
||||
input_tensors_[index].name)
|
||||
.c_str());
|
||||
// SendErrorForResponses(responses, request_count, err);
|
||||
return err;
|
||||
}
|
||||
|
||||
std::string in_name = std::string(input_name);
|
||||
std::vector<int64_t> batchn_shape;
|
||||
// For a ragged input tensor, the tensor shape should be
|
||||
// the flatten shape of the whole batch
|
||||
@@ -1082,23 +1059,40 @@ TRITONSERVER_Error* ModelInstanceState::SetInputTensors(
|
||||
}
|
||||
}
|
||||
|
||||
const char* input_buffer;
|
||||
size_t batchn_byte_size;
|
||||
TRITONSERVER_MemoryType memory_type;
|
||||
int64_t device_id = 0;
|
||||
fastdeploy::Device device;
|
||||
int64_t memory_type_id;
|
||||
std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>>
|
||||
allowed_input_types;
|
||||
if (Kind() == TRITONSERVER_INSTANCEGROUPKIND_GPU) {
|
||||
memory_type = TRITONSERVER_MEMORY_GPU;
|
||||
allowed_input_types = {{TRITONSERVER_MEMORY_GPU, DeviceId()},
|
||||
{TRITONSERVER_MEMORY_CPU_PINNED, 0},
|
||||
{TRITONSERVER_MEMORY_CPU, 0}};
|
||||
} else {
|
||||
allowed_input_types = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
|
||||
{TRITONSERVER_MEMORY_CPU, 0}};
|
||||
}
|
||||
|
||||
RETURN_IF_ERROR(
|
||||
collector->ProcessTensor(
|
||||
input_name, nullptr, 0, allowed_input_types, &input_buffer,
|
||||
&batchn_byte_size, &memory_type, &memory_type_id));
|
||||
|
||||
int32_t device_id = -1;
|
||||
fastdeploy::Device device;
|
||||
if (memory_type == TRITONSERVER_MEMORY_GPU) {
|
||||
device_id = DeviceId();
|
||||
device = fastdeploy::Device::GPU;
|
||||
} else {
|
||||
memory_type = TRITONSERVER_MEMORY_CPU;
|
||||
device = fastdeploy::Device::CPU;
|
||||
}
|
||||
input_tensors_[index].Resize(
|
||||
batchn_shape, ConvertDataTypeToFD(input_datatype), input_name, device);
|
||||
collector->ProcessTensor(
|
||||
input_name,
|
||||
reinterpret_cast<char*>(input_tensors_[index].MutableData()),
|
||||
input_tensors_[index].Nbytes(), memory_type, device_id);
|
||||
|
||||
fastdeploy::FDTensor fdtensor(in_name);
|
||||
fdtensor.SetExternalData(
|
||||
batchn_shape, ConvertDataTypeToFD(input_datatype),
|
||||
const_cast<char*>(input_buffer), device, device_id);
|
||||
runtime_->BindInputTensor(in_name, fdtensor);
|
||||
}
|
||||
|
||||
// Finalize...
|
||||
@@ -1134,12 +1128,25 @@ TRITONSERVER_Error* ModelInstanceState::ReadOutputTensors(
|
||||
// }
|
||||
|
||||
for (auto& output_name : output_names_) {
|
||||
int idx = GetInfoIndex(output_name, output_tensor_infos_);
|
||||
auto* output_tensor = runtime_->GetOutputTensor(output_name);
|
||||
if (output_tensor == nullptr) {
|
||||
RETURN_IF_ERROR(
|
||||
TRITONSERVER_ErrorNew(
|
||||
TRITONSERVER_ERROR_INTERNAL,
|
||||
(std::string("output tensor '") + output_name + "' is not found")
|
||||
.c_str()));
|
||||
}
|
||||
TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU;
|
||||
int64_t memory_type_id = 0;
|
||||
if(output_tensor->device == fastdeploy::Device::GPU) {
|
||||
memory_type = TRITONSERVER_MEMORY_GPU;
|
||||
memory_type_id = DeviceId();
|
||||
}
|
||||
responder.ProcessTensor(
|
||||
output_tensors_[idx].name, ConvertFDType(output_tensors_[idx].dtype),
|
||||
output_tensors_[idx].shape,
|
||||
reinterpret_cast<char*>(output_tensors_[idx].MutableData()),
|
||||
TRITONSERVER_MEMORY_CPU, 0);
|
||||
output_tensor->name, ConvertFDType(output_tensor->dtype),
|
||||
output_tensor->shape,
|
||||
reinterpret_cast<char*>(output_tensor->MutableData()),
|
||||
memory_type, memory_type_id);
|
||||
}
|
||||
|
||||
// Finalize and wait for any pending buffer copies.
|
||||
|
Reference in New Issue
Block a user