diff --git a/examples/vision/detection/paddledetection/python/serving/README.md b/examples/vision/detection/paddledetection/python/serving/README.md new file mode 120000 index 000000000..bacd3186b --- /dev/null +++ b/examples/vision/detection/paddledetection/python/serving/README.md @@ -0,0 +1 @@ +README_CN.md \ No newline at end of file diff --git a/examples/vision/detection/paddledetection/python/serving/README_CN.md b/examples/vision/detection/paddledetection/python/serving/README_CN.md new file mode 100644 index 000000000..b9bfe008d --- /dev/null +++ b/examples/vision/detection/paddledetection/python/serving/README_CN.md @@ -0,0 +1,43 @@ +简体中文 | [English](README_EN.md) + +# PaddleDetection Python轻量服务化部署示例 + +在部署前,需确认以下两个步骤 + +- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md) +- 2. FastDeploy Python whl包安装,参考[FastDeploy Python安装](../../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md) + +服务端: +```bash +# 下载部署示例代码 +git clone https://github.com/PaddlePaddle/FastDeploy.git +cd FastDeploy/examples/vision/detection/paddledetection/python/serving + +# 下载PPYOLOE模型文件(如果不下载,代码里会自动从hub下载) +wget https://bj.bcebos.com/paddlehub/fastdeploy/ppyoloe_crn_l_300e_coco.tgz +tar xvf ppyoloe_crn_l_300e_coco.tgz + +# 安装uvicorn +pip install uvicorn + +# 启动服务,可选择是否使用GPU和TensorRT,可根据uvicorn --help配置IP、端口号等 +# CPU +MODEL_DIR=ppyoloe_crn_l_300e_coco DEVICE=cpu uvicorn server:app +# GPU +MODEL_DIR=ppyoloe_crn_l_300e_coco DEVICE=gpu uvicorn server:app +# GPU上使用TensorRT (注意:TensorRT推理第一次运行,有序列化模型的操作,有一定耗时,需要耐心等待) +MODEL_DIR=ppyoloe_crn_l_300e_coco DEVICE=gpu USE_TRT=true uvicorn server:app +``` + +客户端: +```bash +# 下载部署示例代码 +git clone https://github.com/PaddlePaddle/FastDeploy.git +cd FastDeploy/examples/vision/detection/paddledetection/python/serving + +# 下载测试图片 +wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg + +# 请求服务,获取推理结果(如有必要,请修改脚本中的IP和端口号) +python client.py +``` diff --git a/examples/vision/detection/paddledetection/python/serving/README_EN.md b/examples/vision/detection/paddledetection/python/serving/README_EN.md new file mode 100644 index 000000000..7ccf087c0 --- /dev/null +++ b/examples/vision/detection/paddledetection/python/serving/README_EN.md @@ -0,0 +1,44 @@ +English | [简体中文](README_CN.md) + +# PaddleDetection Python Simple Serving Demo + + +## Environment + +- 1. Prepare environment and install FastDeploy Python whl, refer to [download_prebuilt_libraries](../../../../../../docs/en/build_and_install/download_prebuilt_libraries.md) + +Server: +```bash +# Download demo code +git clone https://github.com/PaddlePaddle/FastDeploy.git +cd FastDeploy/examples/vision/detection/paddledetection/python/serving + +# Download PPYOLOE model +wget https://bj.bcebos.com/paddlehub/fastdeploy/ppyoloe_crn_l_300e_coco.tgz +tar xvf ppyoloe_crn_l_300e_coco.tgz + +# Install uvicorn +pip install uvicorn + +# Launch server, it's configurable to use GPU and TensorRT, +# and run 'uvicorn --help' to check how to specify IP and port, etc. +# CPU +MODEL_DIR=ppyoloe_crn_l_300e_coco DEVICE=cpu uvicorn server:app +# GPU +MODEL_DIR=ppyoloe_crn_l_300e_coco DEVICE=gpu uvicorn server:app +# GPU and TensorRT +MODEL_DIR=ppyoloe_crn_l_300e_coco DEVICE=gpu USE_TRT=true uvicorn server:app +``` + +Client: +```bash +# Download demo code +git clone https://github.com/PaddlePaddle/FastDeploy.git +cd FastDeploy/examples/vision/detection/paddledetection/python/serving + +# Download test image +wget https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg + +# Send request and get inference result (Please adapt the IP and port if necessary) +python client.py +``` diff --git a/examples/vision/detection/paddledetection/python/serving/client.py b/examples/vision/detection/paddledetection/python/serving/client.py new file mode 100644 index 000000000..c10f61976 --- /dev/null +++ b/examples/vision/detection/paddledetection/python/serving/client.py @@ -0,0 +1,28 @@ +import requests +import json +import cv2 +import base64 +import fastdeploy as fd + +if __name__ == '__main__': + url = "http://127.0.0.1:8000/fd/ppyoloe" + headers = {"Content-Type": "application/json"} + + im = cv2.imread("000000014439.jpg") + data = { + "data": { + "image": fd.serving.utils.cv2_to_base64(im) + }, + "parameters": {} + } + + resp = requests.post(url=url, headers=headers, data=json.dumps(data)) + if resp.status_code == 200: + r_json = json.loads(resp.json()["result"]) + det_result = fd.vision.utils.json_to_detection(r_json) + vis_im = fd.vision.vis_detection(im, det_result, score_threshold=0.5) + cv2.imwrite("visualized_result.jpg", vis_im) + print("Visualized result save in ./visualized_result.jpg") + else: + print("Error code:", resp.status_code) + print(resp.text) diff --git a/examples/vision/detection/paddledetection/python/serving/server.py b/examples/vision/detection/paddledetection/python/serving/server.py new file mode 100644 index 000000000..5127cbd44 --- /dev/null +++ b/examples/vision/detection/paddledetection/python/serving/server.py @@ -0,0 +1,40 @@ +import fastdeploy as fd +import os +import logging + +logging.getLogger().setLevel(logging.INFO) + +# Get arguments from envrionment variables +model_dir = os.environ.get('MODEL_DIR') +device = os.environ.get('DEVICE', 'cpu') +use_trt = os.environ.get('USE_TRT', False) + +# Prepare model, download from hub or use local dir +if model_dir is None: + model_dir = fd.download_model(name='ppyoloe_crn_l_300e_coco') + +model_file = os.path.join(model_dir, "model.pdmodel") +params_file = os.path.join(model_dir, "model.pdiparams") +config_file = os.path.join(model_dir, "infer_cfg.yml") + +# Setup runtime option to select hardware, backend, etc. +option = fd.RuntimeOption() +if device.lower() == 'gpu': + option.use_gpu() +if use_trt: + option.use_trt_backend() + option.set_trt_cache_file('ppyoloe.trt') + +# Create model instance +model_instance = fd.vision.detection.PPYOLOE( + model_file=model_file, + params_file=params_file, + config_file=config_file, + runtime_option=option) + +# Create server, setup REST API +app = fd.serving.SimpleServer() +app.register( + task_name="fd/ppyoloe", + model_handler=fd.serving.handler.VisionModelHandler, + predictor=model_instance) diff --git a/python/fastdeploy/__init__.py b/python/fastdeploy/__init__.py index b767393f1..42db5c281 100644 --- a/python/fastdeploy/__init__.py +++ b/python/fastdeploy/__init__.py @@ -37,3 +37,4 @@ from . import vision from . import pipeline from . import text from .download import download, download_and_decompress, download_model +from . import serving diff --git a/python/fastdeploy/serving/__init__.py b/python/fastdeploy/serving/__init__.py new file mode 100644 index 000000000..e40cbb900 --- /dev/null +++ b/python/fastdeploy/serving/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import + +from .server import SimpleServer diff --git a/python/fastdeploy/serving/handler/__init__.py b/python/fastdeploy/serving/handler/__init__.py new file mode 100644 index 000000000..a1e40793c --- /dev/null +++ b/python/fastdeploy/serving/handler/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from .base_handler import BaseModelHandler +from .vision_model_handler import VisionModelHandler diff --git a/python/fastdeploy/serving/handler/base_handler.py b/python/fastdeploy/serving/handler/base_handler.py new file mode 100644 index 000000000..ab6a34427 --- /dev/null +++ b/python/fastdeploy/serving/handler/base_handler.py @@ -0,0 +1,28 @@ +# coding:utf-8 +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from abc import ABCMeta, abstractmethod + + +class BaseModelHandler(metaclass=ABCMeta): + def __init__(self): + super().__init__() + + @classmethod + @abstractmethod + def process(cls, predictor, data, parameters): + pass + diff --git a/python/fastdeploy/serving/handler/vision_model_handler.py b/python/fastdeploy/serving/handler/vision_model_handler.py new file mode 100644 index 000000000..dc14c0c3f --- /dev/null +++ b/python/fastdeploy/serving/handler/vision_model_handler.py @@ -0,0 +1,30 @@ +# coding:utf-8 +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .base_handler import BaseModelHandler +from ..utils import base64_to_cv2 +from ...vision.utils import fd_result_to_json + + +class VisionModelHandler(BaseModelHandler): + def __init__(self): + super().__init__() + + @classmethod + def process(cls, predictor, data, parameters): + # TODO: support batch predict + im = base64_to_cv2(data['image']) + result = predictor.predict(im) + r_str = fd_result_to_json(result) + return r_str diff --git a/python/fastdeploy/serving/model_manager.py b/python/fastdeploy/serving/model_manager.py new file mode 100644 index 000000000..ed252d133 --- /dev/null +++ b/python/fastdeploy/serving/model_manager.py @@ -0,0 +1,57 @@ +# coding:utf-8 +# copyright (c) 2022 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license" +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. + +import os +import time +import json +import logging +import threading +# from .predictor import Predictor +from .handler import BaseModelHandler +from .utils import lock_predictor + + +class ModelManager: + def __init__(self, model_handler, predictor): + self._model_handler = model_handler + self._predictors = [] + self._predictor_locks = [] + self._register(predictor) + + def _register(self, predictor): + # Get the model handler + if not issubclass(self._model_handler, BaseModelHandler): + raise TypeError( + "The model_handler must be subclass of BaseModelHandler, please check the type." + ) + + # TODO: Create multiple predictors to run on different GPUs or different CPU threads + self._predictors.append(predictor) + self._predictor_locks.append(threading.Lock()) + + def _get_predict_id(self): + t = time.time() + t = int(round(t * 1000)) + predictor_id = t % len(self._predictors) + logging.info("The predictor id: {} is selected by running the model.". + format(predictor_id)) + return predictor_id + + def predict(self, data, parameters): + predictor_id = self._get_predict_id() + with lock_predictor(self._predictor_locks[predictor_id]): + model_output = self._model_handler.process( + self._predictors[predictor_id], data, parameters) + return model_output diff --git a/python/fastdeploy/serving/router/__init__.py b/python/fastdeploy/serving/router/__init__.py new file mode 100644 index 000000000..c3ee45631 --- /dev/null +++ b/python/fastdeploy/serving/router/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from .base_router import BaseRouterManager +from .http_router import HttpRouterManager diff --git a/python/fastdeploy/serving/router/base_router.py b/python/fastdeploy/serving/router/base_router.py new file mode 100644 index 000000000..986d31b5f --- /dev/null +++ b/python/fastdeploy/serving/router/base_router.py @@ -0,0 +1,28 @@ +# coding:utf-8 +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc + + +class BaseRouterManager(abc.ABC): + _app = None + + def __init__(self, app): + super().__init__() + self._app = app + + @abc.abstractmethod + def register_models_router(self): + return NotImplemented diff --git a/python/fastdeploy/serving/router/http_router.py b/python/fastdeploy/serving/router/http_router.py new file mode 100644 index 000000000..b35640f89 --- /dev/null +++ b/python/fastdeploy/serving/router/http_router.py @@ -0,0 +1,80 @@ +# coding:utf-8 +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import hashlib +import typing +import logging +from typing import Optional + +from fastapi import APIRouter, Request, HTTPException +from pydantic import BaseModel, Extra, create_model + +from .base_router import BaseRouterManager + + +class ResponseBase(BaseModel): + text: Optional[str] = None + + +class RequestBase(BaseModel, extra=Extra.forbid): + parameters: Optional[dict] = {} + + +class HttpRouterManager(BaseRouterManager): + def register_models_router(self, task_name): + + # Url path to register the model + paths = [f"/{task_name}"] + for path in paths: + logging.info("FastDeploy Model request [path]={} is genereated.". + format(path)) + + # Unique name to create the pydantic model + unique_name = hashlib.md5(task_name.encode()).hexdigest() + + # Create request model + req_model = create_model( + "RequestModel" + unique_name, + data=(typing.Any, ...), + __base__=RequestBase, ) + + # Create response model + resp_model = create_model( + "ResponseModel" + unique_name, + result=(typing.Any, ...), + __base__=ResponseBase, ) + + # Template predict endpoint function to dynamically serve different models + def predict(request: Request, inference_request: req_model): + try: + result = self._app._model_manager.predict( + inference_request.data, inference_request.parameters) + except Exception as e: + raise HTTPException( + status_code=400, + detail=f"Error occurred while running predict: {str(e)}") + return {"result": result} + + # Register the route and add to the app + router = APIRouter() + for path in paths: + router.add_api_route( + path, + predict, + methods=["post"], + summary=f"{task_name.title()}", + response_model=resp_model, + response_model_exclude_unset=True, + response_model_exclude_none=True, ) + self._app.include_router(router) diff --git a/python/fastdeploy/serving/server.py b/python/fastdeploy/serving/server.py new file mode 100644 index 000000000..9f43d8592 --- /dev/null +++ b/python/fastdeploy/serving/server.py @@ -0,0 +1,46 @@ +# coding:utf-8 +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from fastapi import FastAPI +from .router import HttpRouterManager +from .model_manager import ModelManager + + +class SimpleServer(FastAPI): + def __init__(self, **kwargs): + """ + Initial function for the FastDeploy SimpleServer. + """ + super().__init__(**kwargs) + self._router_manager = HttpRouterManager(self) + self._model_manager = None + self._service_name = "FastDeploy SimpleServer" + self._service_type = None + + def register(self, task_name, model_handler, predictor): + """ + The register function for the SimpleServer, the main register argrument as follows: + + Args: + task_name(str): API URL path. + model_handler: To process request data, run predictor, + and can also add your custom post processing on top of the predictor result + predictor: To run model predict + """ + self._server_type = "models" + model_manager = ModelManager(model_handler, predictor) + self._model_manager = model_manager + # Register model server router + self._router_manager.register_models_router(task_name) diff --git a/python/fastdeploy/serving/utils.py b/python/fastdeploy/serving/utils.py new file mode 100644 index 000000000..405ad5a20 --- /dev/null +++ b/python/fastdeploy/serving/utils.py @@ -0,0 +1,40 @@ +# coding:utf-8 +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import base64 +import numpy as np +import cv2 + + +@contextlib.contextmanager +def lock_predictor(lock): + lock.acquire() + try: + yield + finally: + lock.release() + + +def cv2_to_base64(image): + data = cv2.imencode('.jpg', image)[1] + return base64.b64encode(data.tobytes()).decode('utf8') + + +def base64_to_cv2(b64str): + data = base64.b64decode(b64str.encode('utf8')) + data = np.fromstring(data, np.uint8) + data = cv2.imdecode(data, cv2.IMREAD_COLOR) + return data diff --git a/python/fastdeploy/vision/utils.py b/python/fastdeploy/vision/utils.py index d4cc30357..f4e06cf27 100644 --- a/python/fastdeploy/vision/utils.py +++ b/python/fastdeploy/vision/utils.py @@ -46,6 +46,81 @@ def classify_to_json(result): return json.dumps(r_json) +def keypoint_to_json(result): + r_json = { + "keypoints": result.keypoints, + "scores": result.scores, + "num_joints": result.num_joints, + } + return json.dumps(r_json) + + +def ocr_to_json(result): + r_json = { + "boxes": result.boxes, + "text": result.text, + "rec_scores": result.rec_scores, + "cls_scores": result.cls_scores, + "cls_labels": result.cls_labels, + } + return json.dumps(r_json) + + +def mot_to_json(result): + r_json = { + "boxes": result.boxes, + "ids": result.ids, + "scores": result.scores, + "class_ids": result.class_ids, + } + return json.dumps(r_json) + + +def face_detection_to_json(result): + r_json = { + "boxes": result.boxes, + "landmarks": result.landmarks, + "scores": result.scores, + "landmarks_per_face": result.landmarks_per_face, + } + return json.dumps(r_json) + + +def face_alignment_to_json(result): + r_json = {"landmarks": result.landmarks, } + return json.dumps(r_json) + + +def face_recognition_to_json(result): + r_json = {"embedding": result.embedding, } + return json.dumps(r_json) + + +def segmentation_to_json(result): + r_json = { + "label_map": result.label_map, + "score_map": result.score_map, + "shape": result.shape, + "contain_score_map": result.contain_score_map, + } + return json.dumps(r_json) + + +def matting_to_json(result): + r_json = { + "alpha": result.alpha, + "foreground": result.foreground, + "shape": result.shape, + "contain_foreground": result.contain_foreground, + } + return json.dumps(r_json) + + +def head_pose_to_json(result): + r_json = {"euler_angles": result.euler_angles, } + return json.dumps(r_json) + + def fd_result_to_json(result): if isinstance(result, list): r_list = [] @@ -58,7 +133,124 @@ def fd_result_to_json(result): return mask_to_json(result) elif isinstance(result, C.vision.ClassifyResult): return classify_to_json(result) + elif isinstance(result, C.vision.KeyPointDetectionResult): + return keypoint_to_json(result) + elif isinstance(result, C.vision.OCRResult): + return ocr_to_json(result) + elif isinstance(result, C.vision.MOTResult): + return mot_to_json(result) + elif isinstance(result, C.vision.FaceDetectionResult): + return face_detection_to_json(result) + elif isinstance(result, C.vision.FaceAlignmentResult): + return face_alignment_to_json(result) + elif isinstance(result, C.vision.FaceRecognitionResult): + return face_recognition_to_json(result) + elif isinstance(result, C.vision.SegmentationResult): + return segmentation_to_json(result) + elif isinstance(result, C.vision.MattingResult): + return matting_to_json(result) + elif isinstance(result, C.vision.HeadPoseResult): + return head_pose_to_json(result) else: assert False, "{} Conversion to JSON format is not supported".format( type(result)) return {} + + +def json_to_mask(result): + mask = C.vision.Mask() + mask.data = result['data'] + mask.shape = result['shape'] + return mask + + +def json_to_detection(result): + masks = [] + for mask in result['masks']: + masks.append(json_to_mask(json.loads(mask))) + det_result = C.vision.DetectionResult() + det_result.boxes = result['boxes'] + det_result.scores = result['scores'] + det_result.label_ids = result['label_ids'] + det_result.masks = masks + det_result.contain_masks = result['contain_masks'] + return det_result + + +def json_to_classify(result): + cls_result = C.vision.ClassifyResult() + cls_result.label_ids = result['label_ids'] + cls_result.scores = result['scores'] + return cls_result + + +def json_to_keypoint(result): + kp_result = C.vision.KeyPointDetectionResult() + kp_result.keypoints = result['keypoints'] + kp_result.scores = result['scores'] + kp_result.num_joints = result['num_joints'] + return kp_result + + +def json_to_ocr(result): + ocr_result = C.vision.OCRResult() + ocr_result.boxes = result['boxes'] + ocr_result.text = result['text'] + ocr_result.rec_scores = result['rec_scores'] + ocr_result.cls_scores = result['cls_scores'] + ocr_result.cls_labels = result['cls_labels'] + return ocr_result + + +def json_to_mot(result): + mot_result = C.vision.MOTResult() + mot_result.boxes = result['boxes'] + mot_result.ids = result['ids'] + mot_result.scores = result['scores'] + mot_result.class_ids = result['class_ids'] + return mot_result + + +def json_to_face_detection(result): + face_result = C.vision.FaceDetectionResult() + face_result.boxes = result['boxes'] + face_result.landmarks = result['landmarks'] + face_result.scores = result['scores'] + face_result.landmarks_per_face = result['landmarks_per_face'] + return face_result + + +def json_to_face_alignment(result): + face_result = C.vision.FaceAlignmentResult() + face_result.landmarks = result['landmarks'] + return face_result + + +def json_to_face_recognition(result): + face_result = C.vision.FaceRecognitionResult() + face_result.embedding = result['embedding'] + return face_result + + +def json_to_segmentation(result): + seg_result = C.vision.SegmentationResult() + seg_result.label_map = result['label_map'] + seg_result.score_map = result['score_map'] + seg_result.shape = result['shape'] + seg_result.contain_score_map = result['contain_score_map'] + return seg_result + + +def json_to_matting(result): + matting_result = C.vision.MattingResult() + matting_result.alpha = result['alpha'] + matting_result.foreground = result['foreground'] + matting_result.shape = result['shape'] + matting_result.contain_foreground = result['contain_foreground'] + return matting_result + + +def json_to_head_pose(result): + hp_result = C.vision.HeadPoseResult() + hp_result.euler_angles = result['euler_angles'] + return hp_result diff --git a/python/requirements.txt b/python/requirements.txt index 2e5fa136c..3fa7f18d1 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -5,3 +5,4 @@ numpy opencv-python fastdeploy-tools==0.0.1 pyyaml +fastapi