diff --git a/.new_docs/api.md b/.new_docs/api.md
new file mode 100644
index 000000000..81c3c6ed6
--- /dev/null
+++ b/.new_docs/api.md
@@ -0,0 +1,4 @@
+# API Reference
+
+- [Python API](./python_apis/index.rst)
+- [C++ API](https://paddlepaddle.github.io/FastDeploy/)
diff --git a/.new_docs/index.rst b/.new_docs/index.rst
index eaea5eb73..b48abf69b 100644
--- a/.new_docs/index.rst
+++ b/.new_docs/index.rst
@@ -14,3 +14,4 @@ FastDeploy
    build_and_install/index
    quick_start/index
+   api.md
diff --git a/.new_docs/python_apis/image_classification.md b/.new_docs/python_apis/image_classification.md
new file mode 100644
index 000000000..b4a88cfed
--- /dev/null
+++ b/.new_docs/python_apis/image_classification.md
@@ -0,0 +1,9 @@
+# Image Classification Model Deployment
+
+## fastdeploy.vision.classification.PaddleClasModel
+
+```{eval-rst}
+.. autoclass:: fastdeploy.vision.classification.PaddleClasModel
+    :members:
+    :inherited-members:
+```
diff --git a/.new_docs/python_apis/index.rst b/.new_docs/python_apis/index.rst
new file mode 100644
index 000000000..4f5e93c3f
--- /dev/null
+++ b/.new_docs/python_apis/index.rst
@@ -0,0 +1,13 @@
+Python API
+=======================================
+
+FastDeploy supports deployment through the Python programming language.
+
+.. toctree::
+   :caption: Python API
+   :maxdepth: 3
+   :titlesonly:
+
+   image_classification.md
+   object_detection.md
+   runtime.md
diff --git a/.new_docs/references/api_reference.md b/.new_docs/python_apis/object_detection.md
similarity index 97%
rename from .new_docs/references/api_reference.md
rename to .new_docs/python_apis/object_detection.md
index 24fc0741f..7c7a93859 100644
--- a/.new_docs/references/api_reference.md
+++ b/.new_docs/python_apis/object_detection.md
@@ -1,4 +1,4 @@
-# Object Detection API Reference
+# Object Detection Model Deployment
 
 ## fastdeploy.vision.detection.PPYOLOE
diff --git a/.new_docs/python_apis/runtime.md b/.new_docs/python_apis/runtime.md
new file mode 100644
index 000000000..2202b70e3
--- /dev/null
+++ b/.new_docs/python_apis/runtime.md
@@ -0,0 +1,19 @@
+# Using the Runtime Module
+
+The FastDeploy Runtime module can be used on its own: with the same code, it quickly deploys Paddle/ONNX models for accelerated inference on different hardware and backends.
+
+## fastdeploy.RuntimeOption
+
+```{eval-rst}
+.. autoclass:: fastdeploy.RuntimeOption
+    :members:
+    :inherited-members:
+```
+
+## fastdeploy.Runtime
+
+```{eval-rst}
+.. autoclass:: fastdeploy.Runtime
+    :members:
+    :inherited-members:
+```
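The `runtime.md` page added above notes that the Runtime module can be used on its own. A minimal sketch of that workflow, using only the APIs documented in this patch (the model path `model.onnx` and the input name `inputs` are illustrative placeholders, not from this change):

```python
import numpy as np
import fastdeploy as fd

# Describe what to load and where to run it.
option = fd.RuntimeOption()
option.set_model_path("model.onnx", model_format=fd.ModelFormat.ONNX)  # placeholder path
option.use_cpu()
option.use_ort_backend()

# Create the runtime and feed a dict keyed by input name.
runtime = fd.Runtime(option)
data = np.random.rand(1, 3, 224, 224).astype("float32")
outputs = runtime.infer({"inputs": data})  # returns a list of numpy.ndarray
```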
diff --git a/VERSION_NUMBER b/VERSION_NUMBER
index 0c62199f1..9349c3d59 100644
--- a/VERSION_NUMBER
+++ b/VERSION_NUMBER
@@ -1 +1 @@
-0.2.1
+0.3.0rc
diff --git a/fastdeploy/libs/__init__.py b/fastdeploy/libs/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/python/fastdeploy/runtime.py b/python/fastdeploy/runtime.py
index 0e11f5f5e..17cb8fe3a 100644
--- a/python/fastdeploy/runtime.py
+++ b/python/fastdeploy/runtime.py
@@ -13,27 +13,50 @@
 # limitations under the License.
 from __future__ import absolute_import
 import logging
+from . import ModelFormat
 from . import c_lib_wrap as C
 
 
 class Runtime:
+    """FastDeploy Runtime object.
+    """
+
     def __init__(self, runtime_option):
+        """Initialize a FastDeploy Runtime object.
+
+        :param runtime_option: (fastdeploy.RuntimeOption)Options for the FastDeploy Runtime
+        """
+
         self._runtime = C.Runtime()
         assert self._runtime.init(
             runtime_option._option), "Initialize Runtime Failed!"
 
     def infer(self, data):
+        """Inference with input data.
+
+        :param data: (dict[str : numpy.ndarray])The input data dict; each key must match an input name of the loaded model
+        :return: list of numpy.ndarray
+        """
        assert isinstance(data, dict) or isinstance(
            data, list), "The input data should be type of dict or list."
        return self._runtime.infer(data)
 
     def num_inputs(self):
+        """Get the number of inputs of the loaded model.
+        """
         return self._runtime.num_inputs()
 
     def num_outputs(self):
+        """Get the number of outputs of the loaded model.
+        """
         return self._runtime.num_outputs()
 
     def get_input_info(self, index):
+        """Get the input information of the loaded model.
+
+        :param index: (int)Index of the input
+        :return: fastdeploy.TensorInfo
+        """
         assert isinstance(
             index, int), "The input parameter index should be type of int."
         assert index < self.num_inputs(
@@ -42,6 +65,11 @@ class Runtime:
         return self._runtime.get_input_info(index)
 
     def get_output_info(self, index):
+        """Get the output information of the loaded model.
+
+        :param index: (int)Index of the output
+        :return: fastdeploy.TensorInfo
+        """
         assert isinstance(
             index, int), "The input parameter index should be type of int."
         assert index < self.num_outputs(
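The introspection helpers above combine naturally; a sketch assuming `runtime` is an initialized `fastdeploy.Runtime` and that the returned `TensorInfo` objects expose `name`, `shape` and `dtype` attributes (an assumption, not spelled out in this patch):

```python
# List every input and output of the loaded model.
for i in range(runtime.num_inputs()):
    info = runtime.get_input_info(i)
    print("input", i, info.name, info.shape, info.dtype)
for i in range(runtime.num_outputs()):
    info = runtime.get_output_info(i)
    print("output", i, info.name, info.shape, info.dtype)
```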
+ """ return self._option.enable_paddle_log_info() def disable_paddle_log_info(self): + """Disable print out the debug log information while using Paddle Inference backend, the log information is disabled by default. + """ return self._option.disable_paddle_log_info() def set_paddle_mkldnn_cache_size(self, cache_size): + """Set size of shape cache while using Paddle Inference backend with MKLDNN enabled, default will cache all the dynamic shape. + """ return self._option.set_paddle_mkldnn_cache_size(cache_size) def enable_lite_fp16(self): + """Enable half precision inference while using Paddle Lite backend on ARM CPU, fp16 is disabled by default. + """ return self._option.enable_lite_fp16() def disable_lite_fp16(self): + """Disable half precision inference while using Paddle Lite backend on ARM CPU, fp16 is disabled by default. + """ return self._option.disable_lite_fp16() def set_lite_power_mode(self, mode): + """Set POWER mode while using Paddle Lite backend on ARM CPU. + """ return self._option.set_lite_power_mode(mode) def set_trt_input_shape(self, @@ -111,6 +182,13 @@ class RuntimeOption: min_shape, opt_shape=None, max_shape=None): + """Set shape range information while using TensorRT backend with loadding a model contains dynamic input shape. While inference with a new input shape out of the set shape range, the tensorrt engine will be rebuilt to expand the shape range information. + + :param tensor_name: (str)Name of input which has dynamic shape + :param min_shape: (list of int)Minimum shape of the input, e.g [1, 3, 224, 224] + :param opt_shape: (list of int)Optimize shape of the input, this offten set as the most common input shape, if set to None, it will keep same with min_shape + :param max_shape: (list of int)Maximum shape of the input, e.g [8, 3, 224, 224], if set to None, it will keep same with the min_shape + """ if opt_shape is None and max_shape is None: opt_shape = min_shape max_shape = min_shape @@ -120,15 +198,25 @@ class RuntimeOption: opt_shape, max_shape) def set_trt_cache_file(self, cache_file_path): + """Set a cache file path while using TensorRT backend. While loading a Paddle/ONNX model with set_trt_cache_file("./tensorrt_cache/model.trt"), if file `./tensorrt_cache/model.trt` exists, it will skip building tensorrt engine and load the cache file directly; if file `./tensorrt_cache/model.trt` doesn't exist, it will building tensorrt engine and save the engine as binary string to the cache file. + + :param cache_file_path: (str)Path of tensorrt cache file + """ return self._option.set_trt_cache_file(cache_file_path) def enable_trt_fp16(self): + """Enable half precision inference while using TensorRT backend, notice that not all the Nvidia GPU support FP16, in those cases, will fallback to FP32 inference. + """ return self._option.enable_trt_fp16() def disable_trt_fp16(self): + """Disable half precision inference while suing TensorRT backend. + """ return self._option.disable_trt_fp16() def set_trt_max_workspace_size(self, trt_max_workspace_size): + """Set max workspace size while using TensorRT backend. 
diff --git a/python/fastdeploy/vision/classification/ppcls/__init__.py b/python/fastdeploy/vision/classification/ppcls/__init__.py
index 256fa3936..5672e8efb 100644
--- a/python/fastdeploy/vision/classification/ppcls/__init__.py
+++ b/python/fastdeploy/vision/classification/ppcls/__init__.py
@@ -25,13 +25,29 @@
                  config_file,
                  runtime_option=None,
                  model_format=ModelFormat.PADDLE):
+        """Load an image classification model exported by PaddleClas.
+
+        :param model_file: (str)Path of model file, e.g. resnet50/inference.pdmodel
+        :param params_file: (str)Path of parameters file, e.g. resnet50/inference.pdiparams; if the model_format is ModelFormat.ONNX, this param will be ignored and can be set as an empty string
+        :param config_file: (str)Path of configuration file for deployment, e.g. resnet50/inference_cls.yaml
+        :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for the inference of this model; if it's None, the default backend on CPU will be used
+        :param model_format: (fastdeploy.ModelFormat)Model format of the loaded model
+        """
+
         super(PaddleClasModel, self).__init__(runtime_option)
 
-        assert model_format == ModelFormat.PADDLE, "PaddleClasModel only support model format of ModelFormat.Paddle now."
+        assert model_format == ModelFormat.PADDLE, "PaddleClasModel only support model format of ModelFormat.PADDLE now."
         self._model = C.vision.classification.PaddleClasModel(
             model_file, params_file, config_file, self._runtime_option,
             model_format)
         assert self.initialized, "PaddleClas model initialize failed."
 
-    def predict(self, input_image, topk=1):
-        return self._model.predict(input_image, topk)
+    def predict(self, im, topk=1):
+        """Classify an input image.
+
+        :param im: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
+        :param topk: (int)Return the topk results ranked by the classification confidence score, default 1
+        :return: ClassifyResult
+        """
+
+        return self._model.predict(im, topk)
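A usage sketch for the wrapper above, reusing the placeholder export paths from its docstring (the test image path is likewise a placeholder):

```python
import cv2
import fastdeploy as fd

model = fd.vision.classification.PaddleClasModel(
    "resnet50/inference.pdmodel", "resnet50/inference.pdiparams",
    "resnet50/inference_cls.yaml")

im = cv2.imread("test.jpg")  # 3-D array, HWC layout, BGR, as predict() expects
result = model.predict(im, topk=5)
print(result)
```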
diff --git a/python/fastdeploy/vision/detection/ppdet/__init__.py b/python/fastdeploy/vision/detection/ppdet/__init__.py
index cbade4336..4497c75ee 100644
--- a/python/fastdeploy/vision/detection/ppdet/__init__.py
+++ b/python/fastdeploy/vision/detection/ppdet/__init__.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 from __future__ import absolute_import
+from typing import Union, List
 import logging
 from .... import FastDeployModel, ModelFormat
 from .... import c_lib_wrap as C
@@ -25,6 +26,14 @@ class PPYOLOE(FastDeployModel):
                  config_file,
                  runtime_option=None,
                  model_format=ModelFormat.PADDLE):
+        """Load a PPYOLOE model exported by PaddleDetection.
+
+        :param model_file: (str)Path of model file, e.g. ppyoloe/model.pdmodel
+        :param params_file: (str)Path of parameters file, e.g. ppyoloe/model.pdiparams; if the model_format is ModelFormat.ONNX, this param will be ignored and can be set as an empty string
+        :param config_file: (str)Path of configuration file for deployment, e.g. ppyoloe/infer_cfg.yml
+        :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for the inference of this model; if it's None, the default backend on CPU will be used
+        :param model_format: (fastdeploy.ModelFormat)Model format of the loaded model
+        """
         super(PPYOLOE, self).__init__(runtime_option)
 
         assert model_format == ModelFormat.PADDLE, "PPYOLOE model only support model format of ModelFormat.Paddle now."
@@ -33,9 +42,15 @@
             model_format)
         assert self.initialized, "PPYOLOE model initialize failed."
 
-    def predict(self, input_image):
-        assert input_image is not None, "The input image data is None."
-        return self._model.predict(input_image)
+    def predict(self, im):
+        """Detect objects in an input image.
+
+        :param im: (numpy.ndarray)The input image data, 3-D array with layout HWC, BGR format
+        :return: DetectionResult
+        """
+
+        assert im is not None, "The input image data is None."
+        return self._model.predict(im)
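A matching sketch for `PPYOLOE`, again reusing the placeholder paths from the docstring:

```python
import cv2
import fastdeploy as fd

model = fd.vision.detection.PPYOLOE(
    "ppyoloe/model.pdmodel", "ppyoloe/model.pdiparams",
    "ppyoloe/infer_cfg.yml")

im = cv2.imread("test.jpg")  # HWC, BGR
result = model.predict(im)   # DetectionResult
print(result)
```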
 
 
 class PPYOLO(PPYOLOE):
     def __init__(self,
                  model_file,
                  params_file,
                  config_file,
                  runtime_option=None,
                  model_format=ModelFormat.PADDLE):
+        """Load a PPYOLO model exported by PaddleDetection.
+
+        :param model_file: (str)Path of model file, e.g. ppyolo/model.pdmodel
+        :param params_file: (str)Path of parameters file, e.g. ppyolo/model.pdiparams; if the model_format is ModelFormat.ONNX, this param will be ignored and can be set as an empty string
+        :param config_file: (str)Path of configuration file for deployment, e.g. ppyolo/infer_cfg.yml
+        :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for the inference of this model; if it's None, the default backend on CPU will be used
+        :param model_format: (fastdeploy.ModelFormat)Model format of the loaded model
+        """
+
         super(PPYOLOE, self).__init__(runtime_option)
 
         assert model_format == ModelFormat.PADDLE, "PPYOLO model only support model format of ModelFormat.Paddle now."
@@ -61,6 +84,15 @@ class PPYOLOv2(PPYOLOE):
                  config_file,
                  runtime_option=None,
                  model_format=ModelFormat.PADDLE):
+        """Load a PPYOLOv2 model exported by PaddleDetection.
+
+        :param model_file: (str)Path of model file, e.g. ppyolov2/model.pdmodel
+        :param params_file: (str)Path of parameters file, e.g. ppyolov2/model.pdiparams; if the model_format is ModelFormat.ONNX, this param will be ignored and can be set as an empty string
+        :param config_file: (str)Path of configuration file for deployment, e.g. ppyolov2/infer_cfg.yml
+        :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for the inference of this model; if it's None, the default backend on CPU will be used
+        :param model_format: (fastdeploy.ModelFormat)Model format of the loaded model
+        """
+
         super(PPYOLOE, self).__init__(runtime_option)
 
         assert model_format == ModelFormat.PADDLE, "PPYOLOv2 model only support model format of ModelFormat.Paddle now."
@@ -77,6 +109,15 @@ class PaddleYOLOX(PPYOLOE):
                  config_file,
                  runtime_option=None,
                  model_format=ModelFormat.PADDLE):
+        """Load a YOLOX model exported by PaddleDetection.
+
+        :param model_file: (str)Path of model file, e.g. yolox/model.pdmodel
+        :param params_file: (str)Path of parameters file, e.g. yolox/model.pdiparams; if the model_format is ModelFormat.ONNX, this param will be ignored and can be set as an empty string
+        :param config_file: (str)Path of configuration file for deployment, e.g. yolox/infer_cfg.yml
+        :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for the inference of this model; if it's None, the default backend on CPU will be used
+        :param model_format: (fastdeploy.ModelFormat)Model format of the loaded model
+        """
+
         super(PPYOLOE, self).__init__(runtime_option)
 
         assert model_format == ModelFormat.PADDLE, "PaddleYOLOX model only support model format of ModelFormat.Paddle now."
@@ -93,6 +134,15 @@ class PicoDet(PPYOLOE):
                  config_file,
                  runtime_option=None,
                  model_format=ModelFormat.PADDLE):
+        """Load a PicoDet model exported by PaddleDetection.
+
+        :param model_file: (str)Path of model file, e.g. picodet/model.pdmodel
+        :param params_file: (str)Path of parameters file, e.g. picodet/model.pdiparams; if the model_format is ModelFormat.ONNX, this param will be ignored and can be set as an empty string
+        :param config_file: (str)Path of configuration file for deployment, e.g. picodet/infer_cfg.yml
+        :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for the inference of this model; if it's None, the default backend on CPU will be used
+        :param model_format: (fastdeploy.ModelFormat)Model format of the loaded model
+        """
+
         super(PPYOLOE, self).__init__(runtime_option)
 
         assert model_format == ModelFormat.PADDLE, "PicoDet model only support model format of ModelFormat.Paddle now."
@@ -109,6 +159,15 @@ class FasterRCNN(PPYOLOE):
                  config_file,
                  runtime_option=None,
                  model_format=ModelFormat.PADDLE):
+        """Load a FasterRCNN model exported by PaddleDetection.
+
+        :param model_file: (str)Path of model file, e.g. fasterrcnn/model.pdmodel
+        :param params_file: (str)Path of parameters file, e.g. fasterrcnn/model.pdiparams; if the model_format is ModelFormat.ONNX, this param will be ignored and can be set as an empty string
+        :param config_file: (str)Path of configuration file for deployment, e.g. fasterrcnn/infer_cfg.yml
+        :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for the inference of this model; if it's None, the default backend on CPU will be used
+        :param model_format: (fastdeploy.ModelFormat)Model format of the loaded model
+        """
+
         super(PPYOLOE, self).__init__(runtime_option)
 
         assert model_format == ModelFormat.PADDLE, "FasterRCNN model only support model format of ModelFormat.Paddle now."
@@ -125,6 +184,15 @@ class YOLOv3(PPYOLOE):
                  config_file,
                  runtime_option=None,
                  model_format=ModelFormat.PADDLE):
+        """Load a YOLOv3 model exported by PaddleDetection.
+
+        :param model_file: (str)Path of model file, e.g. yolov3/model.pdmodel
+        :param params_file: (str)Path of parameters file, e.g. yolov3/model.pdiparams; if the model_format is ModelFormat.ONNX, this param will be ignored and can be set as an empty string
+        :param config_file: (str)Path of configuration file for deployment, e.g. yolov3/infer_cfg.yml
+        :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for the inference of this model; if it's None, the default backend on CPU will be used
+        :param model_format: (fastdeploy.ModelFormat)Model format of the loaded model
+        """
+
         super(PPYOLOE, self).__init__(runtime_option)
 
         assert model_format == ModelFormat.PADDLE, "YOLOv3 model only support model format of ModelFormat.Paddle now."
@@ -141,6 +209,15 @@ class MaskRCNN(FastDeployModel):
                  config_file,
                  runtime_option=None,
                  model_format=ModelFormat.PADDLE):
+        """Load a MaskRCNN model exported by PaddleDetection.
+
+        :param model_file: (str)Path of model file, e.g. maskrcnn/model.pdmodel
+        :param params_file: (str)Path of parameters file, e.g. maskrcnn/model.pdiparams; if the model_format is ModelFormat.ONNX, this param will be ignored and can be set as an empty string
+        :param config_file: (str)Path of configuration file for deployment, e.g. maskrcnn/infer_cfg.yml
+        :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for the inference of this model; if it's None, the default backend on CPU will be used
+        :param model_format: (fastdeploy.ModelFormat)Model format of the loaded model
+        """
+
         super(MaskRCNN, self).__init__(runtime_option)
 
         assert model_format == ModelFormat.PADDLE, "MaskRCNN model only support model format of ModelFormat.Paddle now."
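Since every model wrapper in this patch accepts a `runtime_option`, backend selection composes with the vision API; a sketch with the same placeholder paths:

```python
import cv2
import fastdeploy as fd

# Run the same PPYOLOE model on GPU through the TensorRT backend.
option = fd.RuntimeOption()
option.use_gpu(0)
option.use_trt_backend()

model = fd.vision.detection.PPYOLOE(
    "ppyoloe/model.pdmodel", "ppyoloe/model.pdiparams",
    "ppyoloe/infer_cfg.yml", runtime_option=option)

result = model.predict(cv2.imread("test.jpg"))
print(result)
```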