prebind output by shareExternalData

This commit is contained in:
wwbitejotunn
2023-02-13 03:11:31 +00:00
parent 59c5fedc36
commit abfa9fd850
8 changed files with 174 additions and 39 deletions

View File

@@ -61,6 +61,43 @@ void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor,
Str(fd_tensor.dtype).c_str());
}
void ShareOutTensorFromFDTensor(paddle_infer::Tensor* tensor,
FDTensor& fd_tensor) {
std::vector<int> shape(fd_tensor.shape.begin(), fd_tensor.shape.end());
auto place = ConvertFDDeviceToPlace(fd_tensor.device);
if (fd_tensor.dtype == FDDataType::FP32) {
if (place == paddle_infer::PlaceType::kGPU) {
tensor->ShareExternalData(static_cast<float*>(fd_tensor.MutableData()),
shape, place);
} else {
tensor->CopyToCpu(static_cast<float*>(fd_tensor.MutableData()));
}
return;
} else if (fd_tensor.dtype == FDDataType::INT32) {
if (place == paddle_infer::PlaceType::kGPU) {
tensor->ShareExternalData(static_cast<int32_t*>(fd_tensor.MutableData()),
shape, place);
} else {
tensor->CopyToCpu(static_cast<int32_t*>(fd_tensor.MutableData()));
}
return;
} else if (fd_tensor.dtype == FDDataType::INT64) {
if (place == paddle_infer::PlaceType::kGPU) {
tensor->ShareExternalData(static_cast<int64_t*>(fd_tensor.MutableData()),
shape, place);
} else {
tensor->CopyToCpu(static_cast<int64_t*>(fd_tensor.MutableData()));
}
return;
} else if (fd_tensor.dtype == FDDataType::UINT8) {
tensor->ShareExternalData(static_cast<uint8_t*>(fd_tensor.MutableData()),
shape, paddle_infer::PlaceType::kCPU);
return;
}
FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
Str(fd_tensor.dtype).c_str());
}
void PaddleTensorToFDTensor(std::unique_ptr<paddle_infer::Tensor>& tensor,
FDTensor* fd_tensor, bool copy_to_fd) {
auto fd_dtype = PaddleDataTypeToFD(tensor->type());