mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
add float16
This commit is contained in:
@@ -100,3 +100,32 @@ std::vector<pybind11::array> PyBackendInfer(
|
||||
}
|
||||
|
||||
} // namespace fastdeploy
|
||||
|
||||
namespace pybind11 {
|
||||
namespace detail {
|
||||
|
||||
// Note: use same enum number of float16 in numpy.
|
||||
// import numpy as np
|
||||
// print np.dtype(np.float16).num # 23
|
||||
constexpr int NPY_FLOAT16_ = 23;
|
||||
|
||||
// Note: Since float16 is not a builtin type in C++, we register
|
||||
// fastdeploy::float16 as numpy.float16.
|
||||
// Ref: https://github.com/pybind/pybind11/issues/1776
|
||||
template <>
|
||||
struct npy_format_descriptor<fastdeploy::float16> {
|
||||
static pybind11::dtype dtype() {
|
||||
handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16_);
|
||||
return reinterpret_borrow<pybind11::dtype>(ptr);
|
||||
}
|
||||
static std::string format() {
|
||||
// Note: "e" represents float16.
|
||||
// Details at:
|
||||
// https://docs.python.org/3/library/struct.html#format-characters.
|
||||
return "e";
|
||||
}
|
||||
static constexpr auto name = _("float16");
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace pybind11
|
||||
|
Reference in New Issue
Block a user