Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Paddle Lite] Support stable-diffusion model (#830)
* support stable-diffusion model for paddlelite
* update code
@@ -206,24 +206,24 @@ bool LiteBackend::Infer(std::vector<FDTensor>& inputs,
     // Adjust dims only, allocate lazy.
     tensor->Resize(inputs[i].shape);
     if (inputs[i].dtype == FDDataType::FP32) {
-      tensor->CopyFromCpu<float, paddle::lite_api::TargetType::kARM>(
+      tensor->CopyFromCpu<float, paddle::lite_api::TargetType::kHost>(
           reinterpret_cast<const float*>(const_cast<void*>(
               inputs[i].CpuData())));
     } else if (inputs[i].dtype == FDDataType::INT32) {
-      tensor->CopyFromCpu<int, paddle::lite_api::TargetType::kARM>(
+      tensor->CopyFromCpu<int, paddle::lite_api::TargetType::kHost>(
           reinterpret_cast<const int*>(const_cast<void*>(
               inputs[i].CpuData())));
     } else if (inputs[i].dtype == FDDataType::INT8) {
-      tensor->CopyFromCpu<int8_t, paddle::lite_api::TargetType::kARM>(
+      tensor->CopyFromCpu<int8_t, paddle::lite_api::TargetType::kHost>(
           reinterpret_cast<const int8_t*>(const_cast<void*>(
               inputs[i].CpuData())));
     } else if (inputs[i].dtype == FDDataType::UINT8) {
-      tensor->CopyFromCpu<uint8_t, paddle::lite_api::TargetType::kARM>(
+      tensor->CopyFromCpu<uint8_t, paddle::lite_api::TargetType::kHost>(
           reinterpret_cast<const uint8_t*>(const_cast<void*>(
               inputs[i].CpuData())));
     } else if (inputs[i].dtype == FDDataType::INT64) {
 #ifdef __aarch64__
-      tensor->CopyFromCpu<int64_t, paddle::lite_api::TargetType::kARM>(
+      tensor->CopyFromCpu<int64_t, paddle::lite_api::TargetType::kHost>(
           reinterpret_cast<const int64_t*>(const_cast<void*>(
               inputs[i].CpuData())));
 #else
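The functional change above is confined to the second template argument of CopyFromCpu: every input copy now tags the source buffer as paddle::lite_api::TargetType::kHost instead of kARM, so the host-to-tensor copy path no longer assumes an ARM target. The per-dtype branches all repeat the same cast-and-copy pattern; a minimal sketch of how they could be collapsed follows. The CopyHostData helper and its TensorPtr parameter are illustrative only and not part of the commit; the CopyFromCpu<T, TargetType> signature and the FDTensor members (dtype, CpuData()) are taken from the diff.

    // Illustrative helper, not in the commit: spell the kHost tag and the
    // const-stripping cast in exactly one place for every dtype branch.
    template <typename T, typename TensorPtr>
    void CopyHostData(TensorPtr& tensor, const FDTensor& input) {
      // `template` keyword is required because TensorPtr is a dependent type.
      tensor->template CopyFromCpu<T, paddle::lite_api::TargetType::kHost>(
          reinterpret_cast<const T*>(const_cast<void*>(input.CpuData())));
    }

    // The dispatch in Infer() would then shrink to calls such as:
    //   if (inputs[i].dtype == FDDataType::FP32) {
    //     CopyHostData<float>(tensor, inputs[i]);
    //   } else if (inputs[i].dtype == FDDataType::INT32) {
    //     CopyHostData<int>(tensor, inputs[i]);
    //   }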
@@ -35,6 +35,8 @@ pybind11::dtype FDDataTypeToNumpyDataType(const FDDataType& fd_dtype) {
     dt = pybind11::dtype::of<double>();
   } else if (fd_dtype == FDDataType::UINT8) {
     dt = pybind11::dtype::of<uint8_t>();
+  } else if (fd_dtype == FDDataType::INT8) {
+    dt = pybind11::dtype::of<int8_t>();
   } else if (fd_dtype == FDDataType::FP16) {
     dt = pybind11::dtype::of<float16>();
   } else {
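This hunk teaches the FDDataType-to-NumPy converter about INT8, mapping it to pybind11::dtype::of<int8_t>(). Below is a short sketch of the direction this converter serves, namely allocating a NumPy array whose element type mirrors a FastDeploy dtype; the MakeInt8Array wrapper is hypothetical, while FDDataTypeToNumpyDataType and FDDataType::INT8 come from the diff.

    #include <pybind11/numpy.h>

    // Hypothetical wrapper, not in the commit: allocate a NumPy array whose
    // dtype mirrors the FastDeploy dtype, exercising the INT8 branch above.
    pybind11::array MakeInt8Array(const std::vector<pybind11::ssize_t>& shape) {
      pybind11::dtype dt = FDDataTypeToNumpyDataType(FDDataType::INT8);
      return pybind11::array(dt, shape);  // fresh buffer with int8 elements
    }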
@@ -55,12 +57,14 @@ FDDataType NumpyDataTypeToFDDataType(const pybind11::dtype& np_dtype) {
     return FDDataType::FP64;
   } else if (np_dtype.is(pybind11::dtype::of<uint8_t>())) {
     return FDDataType::UINT8;
+  } else if (np_dtype.is(pybind11::dtype::of<int8_t>())) {
+    return FDDataType::INT8;
   } else if (np_dtype.is(pybind11::dtype::of<float16>())) {
     return FDDataType::FP16;
   }
   FDASSERT(false,
            "NumpyDataTypeToFDDataType() only support "
-           "int32/int64/float32/float64/float16 now.");
+           "int8/int32/int64/float32/float64/float16 now.");
   return FDDataType::FP32;
 }
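Together with the previous hunk this closes the int8 round trip between NumPy and FDDataType, and the FDASSERT message is extended to list the newly supported type. A quick self-check under that assumption: only the two converter names come from the hunks, the CheckInt8RoundTrip wrapper is illustrative, and the identity-based comparison via dtype.is() mirrors how the converter itself matches dtypes above.

    #include <cassert>
    #include <pybind11/numpy.h>

    void CheckInt8RoundTrip() {
      // dtype::of<int8_t>() is the canonical NumPy int8 descriptor; the
      // converter above compares against it with the identity test `is`.
      pybind11::dtype np_int8 = pybind11::dtype::of<int8_t>();
      assert(NumpyDataTypeToFDDataType(np_int8) == FDDataType::INT8);
      assert(FDDataTypeToNumpyDataType(FDDataType::INT8).is(np_int8));
    }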
@@ -67,9 +67,11 @@ FDDataType CTypeToFDDataType() {
     return FDDataType::FP32;
   } else if (std::is_same<T, double>::value) {
     return FDDataType::FP64;
+  } else if (std::is_same<T, int8_t>::value) {
+    return FDDataType::INT8;
   }
-  FDASSERT(false,
-           "CTypeToFDDataType only support int32/int64/float32/float64 now.");
+  FDASSERT(false, "CTypeToFDDataType only support "
+                  "int8/int32/int64/float32/float64 now.");
   return FDDataType::FP32;
 }
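CTypeToFDDataType<T>() maps a compile-time C type to the runtime FDDataType enum, and with this hunk it can be instantiated with T = int8_t. Below is a sketch of the kind of templated consumer such a mapping enables; the FillTensorMeta helper is hypothetical, and the assumption that FDTensor exposes assignable shape and dtype members (with shape a std::vector<int64_t>) is extrapolated from their use in the Infer() hunk above.

    // Hypothetical consumer, not in the commit: record the dtype and shape
    // of a typed buffer in one templated call.
    template <typename T>
    void FillTensorMeta(FDTensor* tensor, const std::vector<int64_t>& shape) {
      tensor->shape = shape;                   // member used in the Infer() hunk
      tensor->dtype = CTypeToFDDataType<T>();  // valid for T = int8_t now
    }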