mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
add float16
This commit is contained in:
@@ -100,3 +100,32 @@ std::vector<pybind11::array> PyBackendInfer(
|
|||||||
}
|
}
|
||||||
|
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
||||||
|
|
||||||
|
namespace pybind11 {
|
||||||
|
namespace detail {
|
||||||
|
|
||||||
|
// Note: use same enum number of float16 in numpy.
|
||||||
|
// import numpy as np
|
||||||
|
// print np.dtype(np.float16).num # 23
|
||||||
|
constexpr int NPY_FLOAT16_ = 23;
|
||||||
|
|
||||||
|
// Note: Since float16 is not a builtin type in C++, we register
|
||||||
|
// fastdeploy::float16 as numpy.float16.
|
||||||
|
// Ref: https://github.com/pybind/pybind11/issues/1776
|
||||||
|
template <>
|
||||||
|
struct npy_format_descriptor<fastdeploy::float16> {
|
||||||
|
static pybind11::dtype dtype() {
|
||||||
|
handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16_);
|
||||||
|
return reinterpret_borrow<pybind11::dtype>(ptr);
|
||||||
|
}
|
||||||
|
static std::string format() {
|
||||||
|
// Note: "e" represents float16.
|
||||||
|
// Details at:
|
||||||
|
// https://docs.python.org/3/library/struct.html#format-characters.
|
||||||
|
return "e";
|
||||||
|
}
|
||||||
|
static constexpr auto name = _("float16");
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace detail
|
||||||
|
} // namespace pybind11
|
||||||
|
Reference in New Issue
Block a user