add float16

This commit is contained in:
zhoushunjie
2022-10-05 02:44:39 +00:00
parent 5147c1c750
commit 6e4319348b

View File

@@ -100,3 +100,32 @@ std::vector<pybind11::array> PyBackendInfer(
} }
} // namespace fastdeploy } // namespace fastdeploy
namespace pybind11 {
namespace detail {
// Note: use same enum number of float16 in numpy.
// import numpy as np
// print np.dtype(np.float16).num # 23
constexpr int NPY_FLOAT16_ = 23;
// Note: Since float16 is not a builtin type in C++, we register
// fastdeploy::float16 as numpy.float16.
// Ref: https://github.com/pybind/pybind11/issues/1776
template <>
struct npy_format_descriptor<fastdeploy::float16> {
static pybind11::dtype dtype() {
handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16_);
return reinterpret_borrow<pybind11::dtype>(ptr);
}
static std::string format() {
// Note: "e" represents float16.
// Details at:
// https://docs.python.org/3/library/struct.html#format-characters.
return "e";
}
static constexpr auto name = _("float16");
};
} // namespace detail
} // namespace pybind11