remote HOSTDEVICE modifier

This commit is contained in:
zhoushunjie
2022-10-02 04:28:30 +00:00
parent 84ee24ff67
commit ec8b183d9c
9 changed files with 92 additions and 242 deletions

View File

@@ -13,6 +13,7 @@
// limitations under the License.
#include "fastdeploy/backends/paddle/paddle_backend.h"
#include "fastdeploy/core/float16.h"
namespace fastdeploy {
paddle_infer::PlaceType ConvertFDDeviceToPlace(Device device) {
@@ -39,6 +40,10 @@ void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor,
tensor->ShareExternalData(static_cast<const int64_t*>(fd_tensor.Data()),
shape, place);
return;
} else if (fd_tensor.dtype == FDDataType::FP16) {
tensor->ShareExternalData(static_cast<const float16*>(fd_tensor.Data()),
shape, place);
return;
}
FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
Str(fd_tensor.dtype).c_str());
@@ -60,6 +65,9 @@ void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
} else if (fd_tensor->dtype == FDDataType::INT64) {
tensor->CopyToCpu(static_cast<int64_t*>(fd_tensor->MutableData()));
return;
} else if (fd_tensor->dtype == FDDataType::FP16) {
tensor->CopyToCpu(static_cast<float16*>(fd_tensor->MutableData()));
return;
}
FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
Str(fd_tensor->dtype).c_str());
@@ -77,7 +85,9 @@ FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype) {
fd_dtype = FDDataType::UINT8;
} else if (dtype == paddle_infer::INT8) {
fd_dtype = FDDataType::INT8;
}else {
} else if (dtype == paddle_infer::FLOAT16) {
fd_dtype = FDDataType::FP16;
} else {
FDASSERT(
false,
"Unexpected data type: %d while call CopyTensorToCpu in PaddleBackend.",