remote HOSTDEVICE modifier

This commit is contained in:
zhoushunjie
2022-10-02 04:28:30 +00:00
parent 84ee24ff67
commit ec8b183d9c
9 changed files with 92 additions and 242 deletions

View File

@@ -35,6 +35,8 @@ std::vector<int64_t> PartialShapeToVec(const ov::PartialShape& shape) {
FDDataType OpenVINODataTypeToFD(const ov::element::Type& type) {
if (type == ov::element::f32) {
return FDDataType::FP32;
} else if (type == ov::element::f16) {
return FDDataType::FP16;
} else if (type == ov::element::f64) {
return FDDataType::FP64;
} else if (type == ov::element::i8) {
@@ -46,7 +48,7 @@ FDDataType OpenVINODataTypeToFD(const ov::element::Type& type) {
} else if (type == ov::element::i64) {
return FDDataType::INT64;
} else {
FDASSERT(false, "Only support float/double/int8/int32/int64 now.");
FDASSERT(false, "Only support float/double/int8/int32/int64/float16 now.");
}
return FDDataType::FP32;
}
@@ -64,8 +66,11 @@ ov::element::Type FDDataTypeToOV(const FDDataType& type) {
return ov::element::i32;
} else if (type == FDDataType::INT64) {
return ov::element::i64;
} else if (type == FDDataType::FP16) {
return ov::element::f16;
}
FDASSERT(false, "Only support float/double/int8/uint8/int32/int64 now.");
FDASSERT(false,
"Only support float/double/int8/uint8/int32/int64/float16 now.");
return ov::element::f32;
}

View File

@@ -18,6 +18,7 @@
#include "fastdeploy/backends/ort/ops/multiclass_nms.h"
#include "fastdeploy/backends/ort/utils.h"
#include "fastdeploy/core/float16.h"
#include "fastdeploy/utils/utils.h"
#ifdef ENABLE_PADDLE_FRONTEND
#include "paddle2onnx/converter.h"
@@ -186,6 +187,9 @@ void OrtBackend::CopyToCpu(const Ort::Value& value, FDTensor* tensor,
} else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
dtype = FDDataType::FP64;
numel *= sizeof(double);
} else if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
dtype = FDDataType::FP16;
numel *= sizeof(float16);
} else {
FDASSERT(
false,

View File

@@ -30,6 +30,8 @@ ONNXTensorElementDataType GetOrtDtype(const FDDataType& fd_dtype) {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
} else if (fd_dtype == FDDataType::INT8) {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
} else if (fd_dtype == FDDataType::FP16) {
return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
}
FDERROR << "Unrecognized fastdeply data type:" << Str(fd_dtype) << "."
<< std::endl;
@@ -45,6 +47,8 @@ FDDataType GetFdDtype(const ONNXTensorElementDataType& ort_dtype) {
return FDDataType::INT32;
} else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
return FDDataType::INT64;
} else if (ort_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) {
return FDDataType::FP16;
}
FDERROR << "Unrecognized ort data type:" << ort_dtype << "." << std::endl;
return FDDataType::FP32;

View File

@@ -13,6 +13,7 @@
// limitations under the License.
#include "fastdeploy/backends/paddle/paddle_backend.h"
#include "fastdeploy/core/float16.h"
namespace fastdeploy {
paddle_infer::PlaceType ConvertFDDeviceToPlace(Device device) {
@@ -39,6 +40,10 @@ void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor,
tensor->ShareExternalData(static_cast<const int64_t*>(fd_tensor.Data()),
shape, place);
return;
} else if (fd_tensor.dtype == FDDataType::FP16) {
tensor->ShareExternalData(static_cast<const float16*>(fd_tensor.Data()),
shape, place);
return;
}
FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
Str(fd_tensor.dtype).c_str());
@@ -60,6 +65,9 @@ void CopyTensorToCpu(std::unique_ptr<paddle_infer::Tensor>& tensor,
} else if (fd_tensor->dtype == FDDataType::INT64) {
tensor->CopyToCpu(static_cast<int64_t*>(fd_tensor->MutableData()));
return;
} else if (fd_tensor->dtype == FDDataType::FP16) {
tensor->CopyToCpu(static_cast<float16*>(fd_tensor->MutableData()));
return;
}
FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
Str(fd_tensor->dtype).c_str());
@@ -77,7 +85,9 @@ FDDataType PaddleDataTypeToFD(const paddle_infer::DataType& dtype) {
fd_dtype = FDDataType::UINT8;
} else if (dtype == paddle_infer::INT8) {
fd_dtype = FDDataType::INT8;
}else {
} else if (dtype == paddle_infer::FLOAT16) {
fd_dtype = FDDataType::FP16;
} else {
FDASSERT(
false,
"Unexpected data type: %d while call CopyTensorToCpu in PaddleBackend.",

View File

@@ -166,6 +166,8 @@ void FDTensor::PrintInfo(const std::string& prefix) {
CalculateStatisInfo<int32_t>(Data(), Numel(), &mean, &max, &min);
} else if (dtype == FDDataType::INT64) {
CalculateStatisInfo<int64_t>(Data(), Numel(), &mean, &max, &min);
} else if (dtype == FDDataType::FP16) {
CalculateStatisInfo<float16>(Data(), Numel(), &mean, &max, &min);
} else {
FDASSERT(false,
"PrintInfo function doesn't support current situation, maybe you "

View File

@@ -20,33 +20,12 @@
#include <iostream>
#include <limits>
#ifdef WITH_GPU
#include <cuda.h>
#endif // WITH_GPU
#if defined(__CUDACC__) && CUDA_VERSION >= 7050
#define FD_CUDA_FP16
#include <cuda_fp16.h>
#endif
#if !defined(_WIN32)
#define FD_ALIGN(x) __attribute__((aligned(x)))
#else
#define FD_ALIGN(x) __declspec(align(x))
#endif
#define CUDA_ARCH_FP16_SUPPORTED(CUDA_ARCH) (CUDA_ARCH >= 600)
#ifdef WITH_GPU
#define HOSTDEVICE __host__ __device__
#define DEVICE __device__
#define HOST __host__
#else
#define HOSTDEVICE
#define DEVICE
#define HOST
#endif
namespace fastdeploy {
struct FD_ALIGN(2) float16 {
@@ -63,24 +42,17 @@ struct FD_ALIGN(2) float16 {
~float16() = default;
// Constructors
#ifdef FD_CUDA_FP16
HOSTDEVICE inline explicit float16(const half& h) { x = h.x; }
#endif // FD_CUDA_FP16
#ifdef FD_WITH_NATIVE_FP16
// __fp16 is a native half precision data type for arm cpu,
// float16_t is an alias for __fp16
HOSTDEVICE inline explicit float16(const float16_t& h) {
inline explicit float16(const float16_t& h) {
x = *reinterpret_cast<const uint16_t*>(&h);
}
#endif
HOSTDEVICE inline explicit float16(float val) {
#if defined(FD_CUDA_FP16)
half tmp = __float2half(val);
x = *reinterpret_cast<uint16_t*>(&tmp);
#elif defined(FD_WITH_NATIVE_FP16)
inline explicit float16(float val) {
#if defined(FD_WITH_NATIVE_FP16)
float32x4_t tmp = vld1q_dup_f32(&val);
float16_t res = vget_lane_f16(vcvt_f16_f32(tmp), 0);
x = *reinterpret_cast<uint16_t*>(&res);
@@ -109,109 +81,85 @@ struct FD_ALIGN(2) float16 {
#endif
}
HOSTDEVICE inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {}
inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {}
template <class T>
HOSTDEVICE inline explicit float16(const T& val)
inline explicit float16(const T& val)
: x(float16(static_cast<float>(val)).x) {}
// Assignment operators
#ifdef FD_CUDA_FP16
HOSTDEVICE inline float16& operator=(const half& rhs) {
x = rhs.x;
return *this;
}
#endif
#ifdef FD_WITH_NATIVE_FP16
HOSTDEVICE inline float16& operator=(const float16_t& rhs) {
inline float16& operator=(const float16_t& rhs) {
x = *reinterpret_cast<const uint16_t*>(&rhs);
return *this;
}
#endif
HOSTDEVICE inline float16& operator=(bool b) {
inline float16& operator=(bool b) {
x = b ? 0x3c00 : 0;
return *this;
}
HOSTDEVICE inline float16& operator=(int8_t val) {
inline float16& operator=(int8_t val) {
x = float16(val).x;
return *this;
}
HOSTDEVICE inline float16& operator=(uint8_t val) {
inline float16& operator=(uint8_t val) {
x = float16(val).x;
return *this;
}
HOSTDEVICE inline float16& operator=(int16_t val) {
inline float16& operator=(int16_t val) {
x = float16(val).x;
return *this;
}
HOSTDEVICE inline float16& operator=(uint16_t val) {
inline float16& operator=(uint16_t val) {
x = float16(val).x;
return *this;
}
HOSTDEVICE inline float16& operator=(int32_t val) {
inline float16& operator=(int32_t val) {
x = float16(val).x;
return *this;
}
HOSTDEVICE inline float16& operator=(uint32_t val) {
inline float16& operator=(uint32_t val) {
x = float16(val).x;
return *this;
}
HOSTDEVICE inline float16& operator=(int64_t val) {
inline float16& operator=(int64_t val) {
x = float16(val).x;
return *this;
}
HOSTDEVICE inline float16& operator=(uint64_t val) {
inline float16& operator=(uint64_t val) {
x = float16(val).x;
return *this;
}
HOSTDEVICE inline float16& operator=(float val) {
inline float16& operator=(float val) {
x = float16(val).x;
return *this;
}
HOSTDEVICE inline float16& operator=(double val) {
inline float16& operator=(double val) {
x = float16(val).x;
return *this;
}
// Conversion opertors
#ifdef FD_CUDA_FP16
HOSTDEVICE inline half to_half() const {
#ifdef CUDA_VERSION >= 9000
__half_raw h;
h.x = x;
return half(h);
#else
half h;
h.x = x;
return h;
#endif
}
#endif // FD_CUDA_FP16
#ifdef FD_WITH_NATIVE_FP16
HOSTDEVICE inline explicit operator float16_t() const {
return *reinterpret_cast<const float16_t*>(this);
}
#endif
HOSTDEVICE inline operator float() const {
#if defined(FD_CUDA_FP16) && (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300)
half tmp = *reinterpret_cast<const half*>(this);
return __half2float(tmp);
#elif defined(FD_WITH_NATIVE_FP16)
inline operator float() const {
#if defined(FD_WITH_NATIVE_FP16)
float16x4_t res = vld1_dup_f16(reinterpret_cast<const float16_t*>(this));
return vgetq_lane_f32(vcvt_f32_f16(res), 0);
@@ -240,41 +188,41 @@ struct FD_ALIGN(2) float16 {
#endif
}
HOSTDEVICE inline explicit operator bool() const { return (x & 0x7fff) != 0; }
inline explicit operator bool() const { return (x & 0x7fff) != 0; }
HOSTDEVICE inline explicit operator int8_t() const {
inline explicit operator int8_t() const {
return static_cast<int8_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator uint8_t() const {
inline explicit operator uint8_t() const {
return static_cast<uint8_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator int16_t() const {
inline explicit operator int16_t() const {
return static_cast<int16_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator uint16_t() const {
inline explicit operator uint16_t() const {
return static_cast<uint16_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator int32_t() const {
inline explicit operator int32_t() const {
return static_cast<int32_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator uint32_t() const {
inline explicit operator uint32_t() const {
return static_cast<uint32_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator int64_t() const {
inline explicit operator int64_t() const {
return static_cast<int64_t>(static_cast<float>(*this));
}
HOSTDEVICE inline explicit operator uint64_t() const {
inline explicit operator uint64_t() const {
return static_cast<uint64_t>(static_cast<float>(*this));
}
HOSTDEVICE inline operator double() const {
inline operator double() const {
return static_cast<double>(static_cast<float>(*this));
}
@@ -309,123 +257,8 @@ struct FD_ALIGN(2) float16 {
static constexpr int32_t minD = minC - subC - 1;
};
// Arithmetic operators for float16 on GPU
#if defined(FD_CUDA_FP16)
HOSTDEVICE inline float16 operator+(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return float16(__hadd(a.to_half(), b.to_half()));
#else
return float16(static_cast<float>(a) + static_cast<float>(b));
#endif
}
HOSTDEVICE inline float16 operator-(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return float16(__hsub(a.to_half(), b.to_half()));
#else
return float16(static_cast<float>(a) - static_cast<float>(b));
#endif
}
HOSTDEVICE inline float16 operator*(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return float16(__hmul(a.to_half(), b.to_half()));
#else
return float16(static_cast<float>(a) * static_cast<float>(b));
#endif
}
HOSTDEVICE inline float16 operator/(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
// TODO(kexinzhao): check which cuda version starts to support __hdiv
float num = __half2float(a.to_half());
float denom = __half2float(b.to_half());
return float16(num / denom);
#else
return float16(static_cast<float>(a) / static_cast<float>(b));
#endif
}
HOSTDEVICE inline float16 operator-(const float16& a) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return float16(__hneg(a.to_half()));
#else
float16 res;
res.x = a.x ^ 0x8000;
return res;
#endif
}
HOSTDEVICE inline float16& operator+=(float16& a, const float16& b) { // NOLINT
a = a + b;
return a;
}
HOSTDEVICE inline float16& operator-=(float16& a, const float16& b) { // NOLINT
a = a - b;
return a;
}
HOSTDEVICE inline float16& operator*=(float16& a, const float16& b) { // NOLINT
a = a * b;
return a;
}
HOSTDEVICE inline float16& operator/=(float16& a, const float16& b) { // NOLINT
a = a / b;
return a;
}
HOSTDEVICE inline bool operator==(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __heq(a.to_half(), b.to_half());
#else
return static_cast<float>(a) == static_cast<float>(b);
#endif
}
HOSTDEVICE inline bool operator!=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hne(a.to_half(), b.to_half());
#else
return static_cast<float>(a) != static_cast<float>(b);
#endif
}
HOSTDEVICE inline bool operator<(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hlt(a.to_half(), b.to_half());
#else
return static_cast<float>(a) < static_cast<float>(b);
#endif
}
HOSTDEVICE inline bool operator<=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hle(a.to_half(), b.to_half());
#else
return static_cast<float>(a) <= static_cast<float>(b);
#endif
}
HOSTDEVICE inline bool operator>(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hgt(a.to_half(), b.to_half());
#else
return static_cast<float>(a) > static_cast<float>(b);
#endif
}
HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hge(a.to_half(), b.to_half());
#else
return static_cast<float>(a) >= static_cast<float>(b);
#endif
}
#elif defined(FD_WITH_NATIVE_FP16)
// Arithmetic operators for float16 on ARMv8.2-A CPU
#if defined(FD_WITH_NATIVE_FP16)
inline float16 operator+(const float16& a, const float16& b) {
float16 res;
asm volatile(
@@ -673,41 +506,28 @@ inline bool operator>=(const float16& a, const float16& b) {
}
#endif
HOSTDEVICE inline float16 raw_uint16_to_float16(uint16_t a) {
float16 res;
res.x = a;
return res;
}
inline float16 raw_uint16_to_float16(uint16_t a) {
float16 res;
res.x = a;
return res;
}
HOSTDEVICE inline bool(isnan)(const float16& a) {
#if defined(FD_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
return __hisnan(a.to_half());
#else
return (a.x & 0x7fff) > 0x7c00;
#endif
}
inline bool(isnan)(const float16& a) { return (a.x & 0x7fff) > 0x7c00; }
HOSTDEVICE inline bool(isinf)(const float16& a) {
return (a.x & 0x7fff) == 0x7c00;
}
inline bool(isinf)(const float16& a) { return (a.x & 0x7fff) == 0x7c00; }
HOSTDEVICE inline bool(isfinite)(const float16& a) {
return !((isnan)(a)) && !((isinf)(a));
}
inline bool(isfinite)(const float16& a) {
return !((isnan)(a)) && !((isinf)(a));
}
HOSTDEVICE inline float16(abs)(const float16& a) {
#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530)
return float16(::fabs(static_cast<float>(a)));
#else
inline float16(abs)(const float16& a) {
return float16(std::abs(static_cast<float>(a)));
#endif
}
inline std::ostream& operator<<(std::ostream& os, const float16& a) {
os << static_cast<float>(a);
return os;
}
}
inline std::ostream& operator<<(std::ostream& os, const float16& a) {
os << static_cast<float>(a);
return os;
}
} // namespace fastdeploy
namespace std {
@@ -772,36 +592,34 @@ struct numeric_limits<fastdeploy::float16> {
static const bool traps = true;
static const bool tinyness_before = false;
HOSTDEVICE static fastdeploy::float16(min)() {
static fastdeploy::float16(min)() {
return fastdeploy::raw_uint16_to_float16(0x400);
}
HOSTDEVICE static fastdeploy::float16 lowest() {
static fastdeploy::float16 lowest() {
return fastdeploy::raw_uint16_to_float16(0xfbff);
}
HOSTDEVICE static fastdeploy::float16(max)() {
static fastdeploy::float16(max)() {
return fastdeploy::raw_uint16_to_float16(0x7bff);
}
HOSTDEVICE static fastdeploy::float16 epsilon() {
static fastdeploy::float16 epsilon() {
return fastdeploy::raw_uint16_to_float16(0x0800);
}
HOSTDEVICE static fastdeploy::float16 round_error() {
return fastdeploy::float16(0.5);
}
HOSTDEVICE static fastdeploy::float16 infinity() {
static fastdeploy::float16 round_error() { return fastdeploy::float16(0.5); }
static fastdeploy::float16 infinity() {
return fastdeploy::raw_uint16_to_float16(0x7c00);
}
HOSTDEVICE static fastdeploy::float16 quiet_NaN() {
static fastdeploy::float16 quiet_NaN() {
return fastdeploy::raw_uint16_to_float16(0x7e00);
}
HOSTDEVICE static fastdeploy::float16 signaling_NaN() {
static fastdeploy::float16 signaling_NaN() {
return fastdeploy::raw_uint16_to_float16(0x7e00);
}
HOSTDEVICE static fastdeploy::float16 denorm_min() {
static fastdeploy::float16 denorm_min() {
return fastdeploy::raw_uint16_to_float16(0x1);
}
};
HOSTDEVICE inline fastdeploy::float16 abs(const fastdeploy::float16& a) {
inline fastdeploy::float16 abs(const fastdeploy::float16& a) {
return fastdeploy::abs(a);
}

View File

@@ -33,6 +33,8 @@ pybind11::dtype FDDataTypeToNumpyDataType(const FDDataType& fd_dtype) {
dt = pybind11::dtype::of<double>();
} else if (fd_dtype == FDDataType::UINT8) {
dt = pybind11::dtype::of<uint8_t>();
} else if (fd_dtype == FDDataType::FP16) {
dt = pybind11::dtype::of<float16>();
} else {
FDASSERT(false, "The function doesn't support data type of %s.",
Str(fd_dtype).c_str());
@@ -51,10 +53,12 @@ FDDataType NumpyDataTypeToFDDataType(const pybind11::dtype& np_dtype) {
return FDDataType::FP64;
} else if (np_dtype.is(pybind11::dtype::of<uint8_t>())) {
return FDDataType::UINT8;
} else if (np_dtype.is(pybind11::dtype::of<float16>())) {
return FDDataType::FP16;
}
FDASSERT(false,
"NumpyDataTypeToFDDataType() only support "
"int32/int64/float32/float64 now.");
"int32/int64/float32/float64/float16 now.");
return FDDataType::FP32;
}

View File

@@ -30,6 +30,8 @@
#include "fastdeploy/text.h"
#endif
#include "fastdeploy/core/float16.h"
namespace fastdeploy {
void BindBackend(pybind11::module&);

View File

@@ -135,6 +135,7 @@ void BindRuntime(pybind11::module& m) {
.value("INT16", FDDataType::INT16)
.value("INT32", FDDataType::INT32)
.value("INT64", FDDataType::INT64)
.value("FP16", FDDataType::FP16)
.value("FP32", FDDataType::FP32)
.value("FP64", FDDataType::FP64)
.value("UINT8", FDDataType::UINT8);