diff --git a/fastdeploy/pybind/fastdeploy_runtime.cc b/fastdeploy/pybind/fastdeploy_runtime.cc index ad4715a0f..cf1d0f029 100644 --- a/fastdeploy/pybind/fastdeploy_runtime.cc +++ b/fastdeploy/pybind/fastdeploy_runtime.cc @@ -67,6 +67,12 @@ void BindRuntime(pybind11::module& m) { pybind11::class_(m, "Runtime") .def(pybind11::init()) .def("init", &Runtime::Init) + .def("infer", + [](Runtime& self, std::vector& inputs) { + std::vector outputs(self.NumOutputs()); + self.Infer(inputs, &outputs); + return outputs; + }) .def("infer", [](Runtime& self, std::map& data) { std::vector inputs(data.size()); @@ -132,6 +138,32 @@ void BindRuntime(pybind11::module& m) { .value("FP64", FDDataType::FP64) .value("UINT8", FDDataType::UINT8); + pybind11::class_(m, "FDTensor", pybind11::buffer_protocol()) + .def(pybind11::init()) + .def("cpu_data", + [](FDTensor& self) { + auto ptr = self.CpuData(); + auto numel = self.Numel(); + auto dtype = FDDataTypeToNumpyDataType(self.dtype); + auto base = pybind11::array(dtype, self.shape); + return pybind11::array(dtype, self.shape, ptr, base); + }) + .def("resize", static_cast(&FDTensor::Resize)) + .def("resize", + static_cast&)>( + &FDTensor::Resize)) + .def( + "resize", + [](FDTensor& self, const std::vector& shape, + const FDDataType& dtype, const std::string& name, + const Device& device) { self.Resize(shape, dtype, name, device); }) + .def("numel", &FDTensor::Numel) + .def("nbytes", &FDTensor::Nbytes) + .def_readwrite("name", &FDTensor::name) + .def_readonly("shape", &FDTensor::shape) + .def_readonly("dtype", &FDTensor::dtype) + .def_readonly("device", &FDTensor::device); + m.def("get_available_backends", []() { return GetAvailableBackends(); }); } diff --git a/fastdeploy/pybind/main.cc.in b/fastdeploy/pybind/main.cc.in index 5aaac049c..e233fdad7 100644 --- a/fastdeploy/pybind/main.cc.in +++ b/fastdeploy/pybind/main.cc.in @@ -73,6 +73,13 @@ void PyArrayToTensor(pybind11::array& pyarray, FDTensor* tensor, } } +void PyArrayToTensorList(std::vector& pyarrays, std::vector* tensors, + bool share_buffer) { + for(auto i = 0; i < pyarrays.size(); ++i) { + PyArrayToTensor(pyarrays[i], &(*tensors)[i], share_buffer); + } +} + pybind11::array TensorToPyArray(const FDTensor& tensor) { auto numpy_dtype = FDDataTypeToNumpyDataType(tensor.dtype); auto out = pybind11::array(numpy_dtype, tensor.shape); diff --git a/fastdeploy/pybind/main.h b/fastdeploy/pybind/main.h index 6c19edb99..6eb3857dd 100644 --- a/fastdeploy/pybind/main.h +++ b/fastdeploy/pybind/main.h @@ -42,6 +42,9 @@ FDDataType NumpyDataTypeToFDDataType(const pybind11::dtype& np_dtype); void PyArrayToTensor(pybind11::array& pyarray, FDTensor* tensor, bool share_buffer = false); +void PyArrayToTensorList(std::vector& pyarray, + std::vector* tensor, + bool share_buffer = false); pybind11::array TensorToPyArray(const FDTensor& tensor); #ifdef ENABLE_VISION diff --git a/fastdeploy/vision/detection/contrib/yolov5.cc b/fastdeploy/vision/detection/contrib/yolov5.cc index 0655b7f3c..b582bf299 100644 --- a/fastdeploy/vision/detection/contrib/yolov5.cc +++ b/fastdeploy/vision/detection/contrib/yolov5.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "fastdeploy/vision/detection/contrib/yolov5.h" + #include "fastdeploy/utils/perf.h" #include "fastdeploy/vision/utils/utils.h" @@ -74,14 +75,14 @@ YOLOv5::YOLOv5(const std::string& model_file, const std::string& params_file, bool YOLOv5::Initialize() { // parameters for preprocess - size = {640, 640}; - padding_value = {114.0, 114.0, 114.0}; - is_mini_pad = false; - is_no_pad = false; - is_scale_up = false; - stride = 32; - max_wh = 7680.0; - multi_label = true; + size_ = {640, 640}; + padding_value_ = {114.0, 114.0, 114.0}; + is_mini_pad_ = false; + is_no_pad_ = false; + is_scale_up_ = false; + stride_ = 32; + max_wh_ = 7680.0; + multi_label_ = true; if (!InitRuntime()) { FDERROR << "Failed to initialize fastdeploy backend." << std::endl; @@ -90,23 +91,34 @@ bool YOLOv5::Initialize() { // Check if the input shape is dynamic after Runtime already initialized, // Note that, We need to force is_mini_pad 'false' to keep static // shape after padding (LetterBox) when the is_dynamic_shape is 'false'. - is_dynamic_input_ = false; - auto shape = InputInfoOfRuntime(0).shape; - for (int i = 0; i < shape.size(); ++i) { - // if height or width is dynamic - if (i >= 2 && shape[i] <= 0) { - is_dynamic_input_ = true; - break; - } - } - if (!is_dynamic_input_) { - is_mini_pad = false; - } + // TODO(qiuyanjun): remove + // is_dynamic_input_ = false; + // auto shape = InputInfoOfRuntime(0).shape; + // for (int i = 0; i < shape.size(); ++i) { + // // if height or width is dynamic + // if (i >= 2 && shape[i] <= 0) { + // is_dynamic_input_ = true; + // break; + // } + // } + // if (!is_dynamic_input_) { + // is_mini_pad_ = false; + // } return true; } bool YOLOv5::Preprocess(Mat* mat, FDTensor* output, - std::map>* im_info) { + std::map>* im_info, + const std::vector& size, + const std::vector padding_value, + bool is_mini_pad, bool is_no_pad, bool is_scale_up, + int stride, float max_wh, bool multi_label) { + // Record the shape of image and the shape of preprocessed image + (*im_info)["input_shape"] = {static_cast(mat->Height()), + static_cast(mat->Width())}; + (*im_info)["output_shape"] = {static_cast(mat->Height()), + static_cast(mat->Width())}; + // process after image load double ratio = (size[0] * 1.0) / std::max(static_cast(mat->Height()), static_cast(mat->Width())); @@ -145,9 +157,11 @@ bool YOLOv5::Preprocess(Mat* mat, FDTensor* output, } bool YOLOv5::Postprocess( - FDTensor& infer_result, DetectionResult* result, + std::vector& infer_results, DetectionResult* result, const std::map>& im_info, - float conf_threshold, float nms_iou_threshold, bool multi_label) { + float conf_threshold, float nms_iou_threshold, bool multi_label, + float max_wh) { + auto& infer_result = infer_results[0]; FDASSERT(infer_result.shape[0] == 1, "Only support batch =1 now."); result->Clear(); if (multi_label) { @@ -251,13 +265,9 @@ bool YOLOv5::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold, std::map> im_info; - // Record the shape of image and the shape of preprocessed image - im_info["input_shape"] = {static_cast(mat.Height()), - static_cast(mat.Width())}; - im_info["output_shape"] = {static_cast(mat.Height()), - static_cast(mat.Width())}; - - if (!Preprocess(&mat, &input_tensors[0], &im_info)) { + if (!Preprocess(&mat, &input_tensors[0], &im_info, size_, padding_value_, + is_mini_pad_, is_no_pad_, is_scale_up_, stride_, max_wh_, + multi_label_)) { FDERROR << "Failed to preprocess input image." << std::endl; return false; } @@ -278,8 +288,8 @@ bool YOLOv5::Predict(cv::Mat* im, DetectionResult* result, float conf_threshold, TIMERECORD_START(2) #endif - if (!Postprocess(output_tensors[0], result, im_info, conf_threshold, - nms_iou_threshold, multi_label)) { + if (!Postprocess(output_tensors, result, im_info, conf_threshold, + nms_iou_threshold, multi_label_)) { FDERROR << "Failed to post process." << std::endl; return false; } diff --git a/fastdeploy/vision/detection/contrib/yolov5.h b/fastdeploy/vision/detection/contrib/yolov5.h index 68c910d23..1d2acb9ae 100644 --- a/fastdeploy/vision/detection/contrib/yolov5.h +++ b/fastdeploy/vision/detection/contrib/yolov5.h @@ -41,38 +41,18 @@ class FASTDEPLOY_DECL YOLOv5 : public FastDeployModel { float conf_threshold = 0.25, float nms_iou_threshold = 0.5); - // 以下为模型在预测时的一些参数,基本是前后处理所需 - // 用户在创建模型后,可根据模型的要求,以及自己的需求 - // 对参数进行修改 - // tuple of (width, height) - std::vector size; - // padding value, size should be same with Channels - std::vector padding_value; - // only pad to the minimum rectange which height and width is times of stride - bool is_mini_pad; - // while is_mini_pad = false and is_no_pad = true, will resize the image to - // the set size - bool is_no_pad; - // if is_scale_up is false, the input image only can be zoom out, the maximum - // resize scale cannot exceed 1.0 - bool is_scale_up; - // padding stride, for is_mini_pad - int stride; - // for offseting the boxes by classes when using NMS - float max_wh; - // for different strategies to get boxes when postprocessing - bool multi_label; - - private: - // 初始化函数,包括初始化后端,以及其它模型推理需要涉及的操作 - bool Initialize(); - // 输入图像预处理操作 // Mat为FastDeploy定义的数据结构 // FDTensor为预处理后的Tensor数据,传给后端进行推理 // im_info为预处理过程保存的数据,在后处理中需要用到 - bool Preprocess(Mat* mat, FDTensor* outputs, - std::map>* im_info); + static bool Preprocess(Mat* mat, FDTensor* output, + std::map>* im_info, + const std::vector& size = {640, 640}, + const std::vector padding_value = {114.0, 114.0, + 114.0}, + bool is_mini_pad = false, bool is_no_pad = false, + bool is_scale_up = false, int stride = 32, + float max_wh = 7680.0, bool multi_label = true); // 后端推理结果后处理,输出给用户 // infer_result 为后端推理后的输出Tensor @@ -81,17 +61,45 @@ class FASTDEPLOY_DECL YOLOv5 : public FastDeployModel { // conf_threshold 后处理时过滤box的置信度阈值 // nms_iou_threshold 后处理时NMS设定的iou阈值 // multi_label 后处理时box选取是否采用多标签方式 - bool Postprocess(FDTensor& infer_result, DetectionResult* result, - const std::map>& im_info, - float conf_threshold, float nms_iou_threshold, - bool multi_label); + static bool Postprocess( + std::vector& infer_results, DetectionResult* result, + const std::map>& im_info, + float conf_threshold, float nms_iou_threshold, bool multi_label, + float max_wh = 7680.0); + + // 以下为模型在预测时的一些参数,基本是前后处理所需 + // 用户在创建模型后,可根据模型的要求,以及自己的需求 + // 对参数进行修改 + // tuple of (width, height) + std::vector size_; + // padding value, size should be same with Channels + std::vector padding_value_; + // only pad to the minimum rectange which height and width is times of stride + bool is_mini_pad_; + // while is_mini_pad = false and is_no_pad = true, will resize the image to + // the set size + bool is_no_pad_; + // if is_scale_up is false, the input image only can be zoom out, the maximum + // resize scale cannot exceed 1.0 + bool is_scale_up_; + // padding stride, for is_mini_pad + int stride_; + // for offseting the boxes by classes when using NMS + float max_wh_; + // for different strategies to get boxes when postprocessing + bool multi_label_; + + private: + // 初始化函数,包括初始化后端,以及其它模型推理需要涉及的操作 + bool Initialize(); // 查看输入是否为动态维度的 不建议直接使用 不同模型的逻辑可能不一致 bool IsDynamicInput() const { return is_dynamic_input_; } - void LetterBox(Mat* mat, std::vector size, std::vector color, - bool _auto, bool scale_fill = false, bool scale_up = true, - int stride = 32); + static void LetterBox(Mat* mat, std::vector size, + std::vector color, bool _auto, + bool scale_fill = false, bool scale_up = true, + int stride = 32); // whether to inference with dynamic shape (e.g ONNX export with dynamic shape // or not.) diff --git a/fastdeploy/vision/detection/contrib/yolov5_pybind.cc b/fastdeploy/vision/detection/contrib/yolov5_pybind.cc index 65ba538b8..24d318a83 100644 --- a/fastdeploy/vision/detection/contrib/yolov5_pybind.cc +++ b/fastdeploy/vision/detection/contrib/yolov5_pybind.cc @@ -26,13 +26,43 @@ void BindYOLOv5(pybind11::module& m) { self.Predict(&mat, &res, conf_threshold, nms_iou_threshold); return res; }) - .def_readwrite("size", &vision::detection::YOLOv5::size) - .def_readwrite("padding_value", &vision::detection::YOLOv5::padding_value) - .def_readwrite("is_mini_pad", &vision::detection::YOLOv5::is_mini_pad) - .def_readwrite("is_no_pad", &vision::detection::YOLOv5::is_no_pad) - .def_readwrite("is_scale_up", &vision::detection::YOLOv5::is_scale_up) - .def_readwrite("stride", &vision::detection::YOLOv5::stride) - .def_readwrite("max_wh", &vision::detection::YOLOv5::max_wh) - .def_readwrite("multi_label", &vision::detection::YOLOv5::multi_label); + .def_static("preprocess", + [](pybind11::array& data, const std::vector& size, + const std::vector padding_value, bool is_mini_pad, + bool is_no_pad, bool is_scale_up, int stride, float max_wh, + bool multi_label) { + auto mat = PyArrayToCvMat(data); + fastdeploy::vision::Mat fd_mat(mat); + FDTensor output; + std::map> im_info; + vision::detection::YOLOv5::Preprocess( + &fd_mat, &output, &im_info, size, padding_value, + is_mini_pad, is_no_pad, is_scale_up, stride, max_wh, + multi_label); + return make_pair(TensorToPyArray(output), im_info); + }) + .def_static( + "postprocess", + [](std::vector infer_results, + const std::map>& im_info, + float conf_threshold, float nms_iou_threshold, bool multi_label, + float max_wh) { + std::vector fd_infer_results(infer_results.size()); + PyArrayToTensorList(infer_results, &fd_infer_results, true); + vision::DetectionResult result; + vision::detection::YOLOv5::Postprocess( + fd_infer_results, &result, im_info, conf_threshold, + nms_iou_threshold, multi_label, max_wh); + return result; + }) + .def_readwrite("size", &vision::detection::YOLOv5::size_) + .def_readwrite("padding_value", + &vision::detection::YOLOv5::padding_value_) + .def_readwrite("is_mini_pad", &vision::detection::YOLOv5::is_mini_pad_) + .def_readwrite("is_no_pad", &vision::detection::YOLOv5::is_no_pad_) + .def_readwrite("is_scale_up", &vision::detection::YOLOv5::is_scale_up_) + .def_readwrite("stride", &vision::detection::YOLOv5::stride_) + .def_readwrite("max_wh", &vision::detection::YOLOv5::max_wh_) + .def_readwrite("multi_label", &vision::detection::YOLOv5::multi_label_); } } // namespace fastdeploy diff --git a/python/fastdeploy/__init__.py b/python/fastdeploy/__init__.py index b27c6e5c8..0f45f5778 100644 --- a/python/fastdeploy/__init__.py +++ b/python/fastdeploy/__init__.py @@ -17,9 +17,10 @@ import os import sys from .c_lib_wrap import (Frontend, Backend, FDDataType, TensorInfo, Device, - is_built_with_gpu, is_built_with_ort, + FDTensor, is_built_with_gpu, is_built_with_ort, is_built_with_paddle, is_built_with_trt, get_default_cuda_directory) + from .runtime import Runtime, RuntimeOption from .model import FastDeployModel from . import c_lib_wrap as C diff --git a/python/fastdeploy/runtime.py b/python/fastdeploy/runtime.py index f19306a2c..42d5ac62e 100644 --- a/python/fastdeploy/runtime.py +++ b/python/fastdeploy/runtime.py @@ -23,7 +23,8 @@ class Runtime: runtime_option._option), "Initialize Runtime Failed!" def infer(self, data): - assert isinstance(data, dict), "The input data should be type of dict." + assert isinstance(data, dict) or isinstance( + data, list), "The input data should be type of dict or list." return self._runtime.infer(data) def num_inputs(self): diff --git a/python/fastdeploy/vision/detection/contrib/yolov5.py b/python/fastdeploy/vision/detection/contrib/yolov5.py index 5a7711dd3..51d505988 100644 --- a/python/fastdeploy/vision/detection/contrib/yolov5.py +++ b/python/fastdeploy/vision/detection/contrib/yolov5.py @@ -37,6 +37,31 @@ class YOLOv5(FastDeployModel): return self._model.predict(input_image, conf_threshold, nms_iou_threshold) + @staticmethod + def preprocess(input_image, + size=[640, 640], + padding_value=[114.0, 114.0, 114.0], + is_mini_pad=False, + is_no_pad=False, + is_scale_up=False, + stride=32, + max_wh=7680.0, + multi_label=True): + return C.vision.detection.YOLOv5.preprocess( + input_image, size, padding_value, is_mini_pad, is_no_pad, + is_scale_up, stride, max_wh, multi_label) + + @staticmethod + def postprocess(infer_result, + im_info, + conf_threshold=0.25, + nms_iou_threshold=0.5, + multi_label=True, + max_wh=7680.0): + return C.vision.detection.YOLOv5.postprocess( + infer_result, im_info, conf_threshold, nms_iou_threshold, + multi_label, max_wh) + # 一些跟YOLOv5模型有关的属性封装 # 多数是预处理相关,可通过修改如model.size = [1280, 1280]改变预处理时resize的大小(前提是模型支持) @property