Add some comments for python api (#327)

* Add some comments for python api

* Update setup.py

* Update runtime.py
Jason
2022-10-09 10:05:18 +08:00
committed by GitHub
parent a3fa5989d2
commit 5d4372955f
11 changed files with 239 additions and 13 deletions

.new_docs/api.md (new file)

@@ -0,0 +1,4 @@
# API Reference
- [Python API](./python_apis/index.rst)
- [C++ API](https://paddlepaddle.github.io/FastDeploy/)


@@ -14,3 +14,4 @@ FastDeploy
   build_and_install/index
   quick_start/index
   api.md


@@ -0,0 +1,9 @@
# Image Classification Model Deployment
## fastdeploy.vision.classification.PaddleClasModel
```{eval-rst}
.. autoclass:: fastdeploy.vision.classification.PaddleClasModel
:members:
:inherited-members:
```


@@ -0,0 +1,13 @@
Python API
=======================================
FastDeploy supports model deployment through the Python programming language.

.. toctree::
   :caption: Python API
   :maxdepth: 3
   :titlesonly:

   image_classification.md
   object_detection.md
   runtime.md


@@ -1,4 +1,4 @@
-# Object Detection API Reference
+# Object Detection Model Deployment
## fastdeploy.vision.detection.PPYOLOE


@@ -0,0 +1,19 @@
# Using the Runtime Module
The FastDeploy Runtime module can be used on its own; with the same code, Paddle/ONNX models can be quickly deployed for accelerated inference on different hardware backends.
## fastdeploy.RuntimeOption
```{eval-rst}
.. autoclass:: fastdeploy.RuntimeOption
:members:
:inherited-members:
```
## fastdeploy.Runtime
```{eval-rst}
.. autoclass:: fastdeploy.Runtime
:members:
:inherited-members:
```
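
The snippet below is a minimal usage sketch for the two classes documented above; it is not part of the original page. The model paths, the input shape, and the assumption that `TensorInfo` exposes a `name` attribute are illustrative.

```python
import numpy as np
import fastdeploy as fd

# Configure the runtime; the model paths here are placeholders.
option = fd.RuntimeOption()
option.set_model_path("model.pdmodel", "model.pdiparams")
option.use_cpu()
option.use_ort_backend()      # ONNX Runtime backend on CPU
option.set_cpu_thread_num(8)

runtime = fd.Runtime(option)

# Inspect the loaded model, then feed a dict keyed by input name.
info = runtime.get_input_info(0)
print("num_inputs:", runtime.num_inputs(), "first input:", info.name)

data = {info.name: np.random.rand(1, 3, 224, 224).astype("float32")}
outputs = runtime.infer(data)  # list of numpy.ndarray
```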


@@ -1 +1 @@
-0.2.1
+0.3.0rc


@@ -13,27 +13,50 @@
# limitations under the License.
from __future__ import absolute_import
import logging
from . import ModelFormat
from . import c_lib_wrap as C


class Runtime:
    """FastDeploy Runtime object.
    """

    def __init__(self, runtime_option):
        """Initialize a FastDeploy Runtime object.

        :param runtime_option: (fastdeploy.RuntimeOption) Options for the FastDeploy Runtime
        """
        self._runtime = C.Runtime()
        assert self._runtime.init(
            runtime_option._option), "Initialize Runtime Failed!"

    def infer(self, data):
        """Run inference with the input data.

        :param data: (dict[str : numpy.ndarray]) The input data dict; keys must match the input names of the loaded model
        :return: list of numpy.ndarray
        """
        assert isinstance(data, dict) or isinstance(
            data, list), "The input data should be type of dict or list."
        return self._runtime.infer(data)

    def num_inputs(self):
        """Get the number of inputs of the loaded model.
        """
        return self._runtime.num_inputs()

    def num_outputs(self):
        """Get the number of outputs of the loaded model.
        """
        return self._runtime.num_outputs()

    def get_input_info(self, index):
        """Get information about an input of the loaded model.

        :param index: (int) Index of the input
        :return: fastdeploy.TensorInfo
        """
        assert isinstance(
            index, int), "The input parameter index should be type of int."
        assert index < self.num_inputs(
@@ -42,6 +65,11 @@ class Runtime:
        return self._runtime.get_input_info(index)

    def get_output_info(self, index):
        """Get information about an output of the loaded model.

        :param index: (int) Index of the output
        :return: fastdeploy.TensorInfo
        """
        assert isinstance(
            index, int), "The input parameter index should be type of int."
        assert index < self.num_outputs(
@@ -51,59 +79,102 @@ class Runtime:
class RuntimeOption:
    """Options for the FastDeploy Runtime.
    """

    def __init__(self):
        self._option = C.RuntimeOption()

    def set_model_path(self,
                       model_path,
                       params_path="",
                       model_format=ModelFormat.PADDLE):
        """Set the paths of the model file and the parameters file.

        :param model_path: (str) Path of the model file
        :param params_path: (str) Path of the parameters file
        :param model_format: (ModelFormat) Format of the model, supports ModelFormat.PADDLE / ModelFormat.ONNX
        """
        return self._option.set_model_path(model_path, params_path,
                                           model_format)

    def use_gpu(self, device_id=0):
        """Run inference on an Nvidia GPU.

        :param device_id: (int) Index of the GPU used for inference, default 0
        """
        return self._option.use_gpu(device_id)

    def use_cpu(self):
        """Run inference on CPU.
        """
        return self._option.use_cpu()

    def set_cpu_thread_num(self, thread_num=-1):
        """Set the number of threads used for CPU inference.

        :param thread_num: (int) Number of threads; if not positive, the backend decides the thread count, default -1
        """
        return self._option.set_cpu_thread_num(thread_num)

    def use_paddle_backend(self):
        """Use the Paddle Inference backend; supports Paddle models on CPU/Nvidia GPU.
        """
        return self._option.use_paddle_backend()

    def use_ort_backend(self):
        """Use the ONNX Runtime backend; supports Paddle/ONNX models on CPU/Nvidia GPU.
        """
        return self._option.use_ort_backend()

    def use_trt_backend(self):
        """Use the TensorRT backend; supports Paddle/ONNX models on Nvidia GPU.
        """
        return self._option.use_trt_backend()

    def use_openvino_backend(self):
        """Use the OpenVINO backend; supports Paddle/ONNX models on CPU.
        """
        return self._option.use_openvino_backend()

    def use_lite_backend(self):
        """Use the Paddle Lite backend; supports Paddle models on ARM CPU.
        """
        return self._option.use_lite_backend()

    def set_paddle_mkldnn(self, use_mkldnn=True):
        """Enable/disable MKLDNN while using the Paddle Inference backend; MKLDNN is enabled by default.
        """
        return self._option.set_paddle_mkldnn(use_mkldnn)

    def enable_paddle_log_info(self):
        """Enable printing debug log information while using the Paddle Inference backend; the log information is disabled by default.
        """
        return self._option.enable_paddle_log_info()

    def disable_paddle_log_info(self):
        """Disable printing debug log information while using the Paddle Inference backend; the log information is disabled by default.
        """
        return self._option.disable_paddle_log_info()

    def set_paddle_mkldnn_cache_size(self, cache_size):
        """Set the size of the shape cache while using the Paddle Inference backend with MKLDNN enabled; by default all dynamic shapes are cached.
        """
        return self._option.set_paddle_mkldnn_cache_size(cache_size)

    def enable_lite_fp16(self):
        """Enable half-precision inference while using the Paddle Lite backend on ARM CPU; FP16 is disabled by default.
        """
        return self._option.enable_lite_fp16()

    def disable_lite_fp16(self):
        """Disable half-precision inference while using the Paddle Lite backend on ARM CPU; FP16 is disabled by default.
        """
        return self._option.disable_lite_fp16()

    def set_lite_power_mode(self, mode):
        """Set the POWER mode while using the Paddle Lite backend on ARM CPU.
        """
        return self._option.set_lite_power_mode(mode)

    def set_trt_input_shape(self,
@@ -111,6 +182,13 @@ class RuntimeOption:
                            min_shape,
                            opt_shape=None,
                            max_shape=None):
        """Set the shape range information while using the TensorRT backend with a model that has dynamic input shapes. If inference is later run with an input shape outside the configured range, the TensorRT engine is rebuilt to expand the shape range information.

        :param tensor_name: (str) Name of the input that has a dynamic shape
        :param min_shape: (list of int) Minimum shape of the input, e.g. [1, 3, 224, 224]
        :param opt_shape: (list of int) Optimal shape of the input, usually the most common input shape; if None, it is kept the same as min_shape
        :param max_shape: (list of int) Maximum shape of the input, e.g. [8, 3, 224, 224]; if None, it is kept the same as min_shape
        """
        if opt_shape is None and max_shape is None:
            opt_shape = min_shape
            max_shape = min_shape
@@ -120,15 +198,25 @@ class RuntimeOption:
                                             opt_shape, max_shape)

    def set_trt_cache_file(self, cache_file_path):
        """Set a cache file path while using the TensorRT backend. When loading a Paddle/ONNX model with set_trt_cache_file("./tensorrt_cache/model.trt"), if the file `./tensorrt_cache/model.trt` exists, building the TensorRT engine is skipped and the cache file is loaded directly; if it does not exist, the engine is built and serialized to the cache file as a binary string.

        :param cache_file_path: (str) Path of the TensorRT cache file
        """
        return self._option.set_trt_cache_file(cache_file_path)

    def enable_trt_fp16(self):
        """Enable half-precision inference while using the TensorRT backend. Note that not all Nvidia GPUs support FP16; in those cases inference falls back to FP32.
        """
        return self._option.enable_trt_fp16()

    def disable_trt_fp16(self):
        """Disable half-precision inference while using the TensorRT backend.
        """
        return self._option.disable_trt_fp16()

    def set_trt_max_workspace_size(self, trt_max_workspace_size):
        """Set the maximum workspace size while using the TensorRT backend.
        """
        return self._option.set_trt_max_workspace_size(trt_max_workspace_size)

    def __repr__(self):
@@ -139,8 +227,7 @@ class RuntimeOption:
                continue
            if hasattr(getattr(self._option, attr), "__call__"):
                continue
            message += " {} : {}\t\n".format(attr, getattr(self._option, attr))
        message.strip("\n")
        message += ")"
        return message
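
To illustrate the TensorRT-related options above, here is a configuration sketch; the ONNX model path and the input name `image` are assumptions, not values from this diff.

```python
import fastdeploy as fd

option = fd.RuntimeOption()
option.set_model_path("model.onnx", model_format=fd.ModelFormat.ONNX)
option.use_gpu(0)
option.use_trt_backend()

# Declare the dynamic-shape range so the engine is not rebuilt for
# batch sizes between 1 and 8 (shapes are illustrative).
option.set_trt_input_shape("image",
                           min_shape=[1, 3, 224, 224],
                           opt_shape=[4, 3, 224, 224],
                           max_shape=[8, 3, 224, 224])

# Cache the built engine so later runs skip the slow engine build.
option.set_trt_cache_file("./tensorrt_cache/model.trt")
option.enable_trt_fp16()  # falls back to FP32 on GPUs without FP16 support

runtime = fd.Runtime(option)
```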


@@ -25,13 +25,29 @@ class PaddleClasModel(FastDeployModel):
                 config_file,
                 runtime_option=None,
                 model_format=ModelFormat.PADDLE):
        """Load an image classification model exported by PaddleClas.

        :param model_file: (str) Path of the model file, e.g. resnet50/inference.pdmodel
        :param params_file: (str) Path of the parameters file, e.g. resnet50/inference.pdiparams; if model_format is ModelFormat.ONNX, this parameter is ignored and can be set to an empty string
        :param config_file: (str) Path of the deployment configuration file, e.g. resnet50/inference_cls.yaml
        :param runtime_option: (fastdeploy.RuntimeOption) RuntimeOption for inference with this model; if None, the default backend on CPU is used
        :param model_format: (fastdeploy.ModelFormat) Model format of the loaded model
        """
        super(PaddleClasModel, self).__init__(runtime_option)
        assert model_format == ModelFormat.PADDLE, "PaddleClasModel only support model format of ModelFormat.PADDLE now."
        self._model = C.vision.classification.PaddleClasModel(
            model_file, params_file, config_file, self._runtime_option,
            model_format)
        assert self.initialized, "PaddleClas model initialize failed."

    def predict(self, im, topk=1):
        """Classify an input image.

        :param im: (numpy.ndarray) The input image data, a 3-D array with HWC layout in BGR format
        :param topk: (int) Return the top-k classification results ranked by confidence score, default 1
        :return: ClassifyResult
        """
        return self._model.predict(im, topk)
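
As a usage sketch for PaddleClasModel (the model directory, the image path, and reading the image with OpenCV are assumptions, not part of the diff):

```python
import cv2
import fastdeploy as fd

# Hypothetical paths to a PaddleClas export
model = fd.vision.classification.PaddleClasModel(
    "resnet50/inference.pdmodel",
    "resnet50/inference.pdiparams",
    "resnet50/inference_cls.yaml")

im = cv2.imread("test.jpg")         # HWC layout, BGR, as predict() expects
result = model.predict(im, topk=5)  # ClassifyResult with the top-5 labels
print(result)
```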


@@ -13,6 +13,7 @@
# limitations under the License.
from __future__ import absolute_import
from typing import Union, List
import logging
from .... import FastDeployModel, ModelFormat
from .... import c_lib_wrap as C
@@ -25,6 +26,14 @@ class PPYOLOE(FastDeployModel):
                 config_file,
                 runtime_option=None,
                 model_format=ModelFormat.PADDLE):
        """Load a PPYOLOE model exported by PaddleDetection.

        :param model_file: (str) Path of the model file, e.g. ppyoloe/model.pdmodel
        :param params_file: (str) Path of the parameters file, e.g. ppyoloe/model.pdiparams; if model_format is ModelFormat.ONNX, this parameter is ignored and can be set to an empty string
        :param config_file: (str) Path of the deployment configuration file, e.g. ppyoloe/infer_cfg.yml
        :param runtime_option: (fastdeploy.RuntimeOption) RuntimeOption for inference with this model; if None, the default backend on CPU is used
        :param model_format: (fastdeploy.ModelFormat) Model format of the loaded model
        """
        super(PPYOLOE, self).__init__(runtime_option)
        assert model_format == ModelFormat.PADDLE, "PPYOLOE model only support model format of ModelFormat.Paddle now."
@@ -33,9 +42,15 @@ class PPYOLOE(FastDeployModel):
            model_format)
        assert self.initialized, "PPYOLOE model initialize failed."

    def predict(self, im):
        """Detect objects in an input image.

        :param im: (numpy.ndarray) The input image data, a 3-D array with HWC layout in BGR format
        :return: DetectionResult
        """
        assert im is not None, "The input image data is None."
        return self._model.predict(im)


class PPYOLO(PPYOLOE):
@@ -45,6 +60,14 @@ class PPYOLO(PPYOLOE):
                 config_file,
                 runtime_option=None,
                 model_format=ModelFormat.PADDLE):
        """Load a PPYOLO model exported by PaddleDetection.

        :param model_file: (str) Path of the model file, e.g. ppyolo/model.pdmodel
        :param params_file: (str) Path of the parameters file, e.g. ppyolo/model.pdiparams; if model_format is ModelFormat.ONNX, this parameter is ignored and can be set to an empty string
        :param runtime_option: (fastdeploy.RuntimeOption) RuntimeOption for inference with this model; if None, the default backend on CPU is used
        :param model_format: (fastdeploy.ModelFormat) Model format of the loaded model
        """
        super(PPYOLOE, self).__init__(runtime_option)
        assert model_format == ModelFormat.PADDLE, "PPYOLO model only support model format of ModelFormat.Paddle now."
@@ -61,6 +84,15 @@ class PPYOLOv2(PPYOLOE):
                 config_file,
                 runtime_option=None,
                 model_format=ModelFormat.PADDLE):
        """Load a PPYOLOv2 model exported by PaddleDetection.

        :param model_file: (str) Path of the model file, e.g. ppyolov2/model.pdmodel
        :param params_file: (str) Path of the parameters file, e.g. ppyolov2/model.pdiparams; if model_format is ModelFormat.ONNX, this parameter is ignored and can be set to an empty string
        :param config_file: (str) Path of the deployment configuration file, e.g. ppyoloe/infer_cfg.yml
        :param runtime_option: (fastdeploy.RuntimeOption) RuntimeOption for inference with this model; if None, the default backend on CPU is used
        :param model_format: (fastdeploy.ModelFormat) Model format of the loaded model
        """
        super(PPYOLOE, self).__init__(runtime_option)
        assert model_format == ModelFormat.PADDLE, "PPYOLOv2 model only support model format of ModelFormat.Paddle now."
@@ -77,6 +109,15 @@ class PaddleYOLOX(PPYOLOE):
                 config_file,
                 runtime_option=None,
                 model_format=ModelFormat.PADDLE):
        """Load a YOLOX model exported by PaddleDetection.

        :param model_file: (str) Path of the model file, e.g. yolox/model.pdmodel
        :param params_file: (str) Path of the parameters file, e.g. yolox/model.pdiparams; if model_format is ModelFormat.ONNX, this parameter is ignored and can be set to an empty string
        :param config_file: (str) Path of the deployment configuration file, e.g. ppyoloe/infer_cfg.yml
        :param runtime_option: (fastdeploy.RuntimeOption) RuntimeOption for inference with this model; if None, the default backend on CPU is used
        :param model_format: (fastdeploy.ModelFormat) Model format of the loaded model
        """
        super(PPYOLOE, self).__init__(runtime_option)
        assert model_format == ModelFormat.PADDLE, "PaddleYOLOX model only support model format of ModelFormat.Paddle now."
@@ -93,6 +134,15 @@ class PicoDet(PPYOLOE):
                 config_file,
                 runtime_option=None,
                 model_format=ModelFormat.PADDLE):
        """Load a PicoDet model exported by PaddleDetection.

        :param model_file: (str) Path of the model file, e.g. picodet/model.pdmodel
        :param params_file: (str) Path of the parameters file, e.g. picodet/model.pdiparams; if model_format is ModelFormat.ONNX, this parameter is ignored and can be set to an empty string
        :param config_file: (str) Path of the deployment configuration file, e.g. ppyoloe/infer_cfg.yml
        :param runtime_option: (fastdeploy.RuntimeOption) RuntimeOption for inference with this model; if None, the default backend on CPU is used
        :param model_format: (fastdeploy.ModelFormat) Model format of the loaded model
        """
        super(PPYOLOE, self).__init__(runtime_option)
        assert model_format == ModelFormat.PADDLE, "PicoDet model only support model format of ModelFormat.Paddle now."
@@ -109,6 +159,15 @@ class FasterRCNN(PPYOLOE):
                 config_file,
                 runtime_option=None,
                 model_format=ModelFormat.PADDLE):
        """Load a FasterRCNN model exported by PaddleDetection.

        :param model_file: (str) Path of the model file, e.g. fasterrcnn/model.pdmodel
        :param params_file: (str) Path of the parameters file, e.g. fasterrcnn/model.pdiparams; if model_format is ModelFormat.ONNX, this parameter is ignored and can be set to an empty string
        :param config_file: (str) Path of the deployment configuration file, e.g. ppyoloe/infer_cfg.yml
        :param runtime_option: (fastdeploy.RuntimeOption) RuntimeOption for inference with this model; if None, the default backend on CPU is used
        :param model_format: (fastdeploy.ModelFormat) Model format of the loaded model
        """
        super(PPYOLOE, self).__init__(runtime_option)
        assert model_format == ModelFormat.PADDLE, "FasterRCNN model only support model format of ModelFormat.Paddle now."
@@ -125,6 +184,15 @@ class YOLOv3(PPYOLOE):
                 config_file,
                 runtime_option=None,
                 model_format=ModelFormat.PADDLE):
        """Load a YOLOv3 model exported by PaddleDetection.

        :param model_file: (str) Path of the model file, e.g. yolov3/model.pdmodel
        :param params_file: (str) Path of the parameters file, e.g. yolov3/model.pdiparams; if model_format is ModelFormat.ONNX, this parameter is ignored and can be set to an empty string
        :param config_file: (str) Path of the deployment configuration file, e.g. ppyoloe/infer_cfg.yml
        :param runtime_option: (fastdeploy.RuntimeOption) RuntimeOption for inference with this model; if None, the default backend on CPU is used
        :param model_format: (fastdeploy.ModelFormat) Model format of the loaded model
        """
        super(PPYOLOE, self).__init__(runtime_option)
        assert model_format == ModelFormat.PADDLE, "YOLOv3 model only support model format of ModelFormat.Paddle now."
@@ -141,6 +209,15 @@ class MaskRCNN(FastDeployModel):
                 config_file,
                 runtime_option=None,
                 model_format=ModelFormat.PADDLE):
        """Load a MaskRCNN model exported by PaddleDetection.

        :param model_file: (str) Path of the model file, e.g. maskrcnn/model.pdmodel
        :param params_file: (str) Path of the parameters file, e.g. maskrcnn/model.pdiparams; if model_format is ModelFormat.ONNX, this parameter is ignored and can be set to an empty string
        :param config_file: (str) Path of the deployment configuration file, e.g. ppyoloe/infer_cfg.yml
        :param runtime_option: (fastdeploy.RuntimeOption) RuntimeOption for inference with this model; if None, the default backend on CPU is used
        :param model_format: (fastdeploy.ModelFormat) Model format of the loaded model
        """
        super(MaskRCNN, self).__init__(runtime_option)
        assert model_format == ModelFormat.PADDLE, "MaskRCNN model only support model format of ModelFormat.Paddle now."
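
Finally, a usage sketch for PPYOLOE, the base class of the detectors above; the export directory, the image path, and the OpenCV read are assumptions rather than values from this diff.

```python
import cv2
import fastdeploy as fd

# Hypothetical paths to a PaddleDetection PPYOLOE export
model = fd.vision.detection.PPYOLOE(
    "ppyoloe/model.pdmodel",
    "ppyoloe/model.pdiparams",
    "ppyoloe/infer_cfg.yml")

im = cv2.imread("street.jpg")  # HWC layout, BGR, as predict() expects
result = model.predict(im)     # DetectionResult with detected boxes and scores
print(result)
```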