From 6e4319348b0744588e521c302f5ab8dbd74bf9aa Mon Sep 17 00:00:00 2001 From: zhoushunjie Date: Wed, 5 Oct 2022 02:44:39 +0000 Subject: [PATCH] add float16 --- fastdeploy/pybind/main.h | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/fastdeploy/pybind/main.h b/fastdeploy/pybind/main.h index e18c6bb22..b31c49f07 100644 --- a/fastdeploy/pybind/main.h +++ b/fastdeploy/pybind/main.h @@ -100,3 +100,32 @@ std::vector PyBackendInfer( } } // namespace fastdeploy + +namespace pybind11 { +namespace detail { + +// Note: use same enum number of float16 in numpy. +// import numpy as np +// print np.dtype(np.float16).num # 23 +constexpr int NPY_FLOAT16_ = 23; + +// Note: Since float16 is not a builtin type in C++, we register +// fastdeploy::float16 as numpy.float16. +// Ref: https://github.com/pybind/pybind11/issues/1776 +template <> +struct npy_format_descriptor { + static pybind11::dtype dtype() { + handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16_); + return reinterpret_borrow(ptr); + } + static std::string format() { + // Note: "e" represents float16. + // Details at: + // https://docs.python.org/3/library/struct.html#format-characters. + return "e"; + } + static constexpr auto name = _("float16"); +}; + +} // namespace detail +} // namespace pybind11