diff --git a/.gitignore b/.gitignore index 04c0f97b8..db38f7705 100644 --- a/.gitignore +++ b/.gitignore @@ -13,12 +13,14 @@ fastdeploy/version.py fastdeploy/core/config.h python/fastdeploy/c_lib_wrap.py python/fastdeploy/LICENSE* +python/build_cpu.sh python/fastdeploy/ThirdPartyNotices* *.so* fpython/astdeploy/libs/third_libs fastdeploy/core/config.h fastdeploy/pybind/main.cc python/fastdeploy/libs/lib* +python/fastdeploy/libs/third_libs __pycache__ build_fd_android.sh python/scripts/process_libraries.py diff --git a/benchmark/run_benchmark_yolo.sh b/benchmark/run_benchmark_yolo.sh old mode 100644 new mode 100755 diff --git a/cmake/opencv.cmake b/cmake/opencv.cmake index 7557e682c..4b2f09293 100644 --- a/cmake/opencv.cmake +++ b/cmake/opencv.cmake @@ -186,6 +186,6 @@ else() endif() find_package(OpenCV REQUIRED PATHS ${OpenCV_DIR}) include_directories(${OpenCV_INCLUDE_DIRS}) - list(APPEND DEPEND_LIBS opencv_core opencv_highgui opencv_imgproc opencv_imgcodecs) + list(APPEND DEPEND_LIBS opencv_core opencv_video opencv_highgui opencv_imgproc opencv_imgcodecs) endif() endif() diff --git a/docs/api/vision_results/README.md b/docs/api/vision_results/README.md index 3df91cc0e..d094c711c 100644 --- a/docs/api/vision_results/README.md +++ b/docs/api/vision_results/README.md @@ -2,13 +2,14 @@ FastDeploy根据视觉模型的任务类型,定义了不同的结构体(`fastdeploy/vision/common/result.h`)来表达模型预测结果,具体如下表所示 -| 结构体 | 文档 | 说明 | 相应模型 | -| :----- | :--- | :---- | :------- | -| ClassifyResult | [C++/Python文档](./classification_result.md) | 图像分类返回结果 | ResNet50、MobileNetV3等 | -| SegmentationResult | [C++/Python文档](./segmentation_result.md) | 图像分割返回结果 | PP-HumanSeg、PP-LiteSeg等 | -| DetectionResult | [C++/Python文档](./detection_result.md) | 目标检测返回结果 | PP-YOLOE、YOLOv7系列模型等 | -| FaceDetectionResult | [C++/Python文档](./face_detection_result.md) | 目标检测返回结果 | SCRFD、RetinaFace系列模型等 | -| KeyPointDetectionResult | [C++/Python文档](./keypointdetection_result.md) | 关键点检测返回结果 | PP-Tinypose系列模型等 | -| FaceRecognitionResult | 
[C++/Python文档](./face_recognition_result.md) | 目标检测返回结果 | ArcFace、CosFace系列模型等 | -| MattingResult | [C++/Python文档](./matting_result.md) | 目标检测返回结果 | MODNet系列模型等 | -| OCRResult | [C++/Python文档](./ocr_result.md) | 文本框检测,分类和文本识别返回结果 | OCR系列模型等 | +| 结构体 | 文档 | 说明 | 相应模型 | +|:------------------------|:----------------------------------------------|:------------------|:------------------------| +| ClassifyResult | [C++/Python文档](./classification_result.md) | 图像分类返回结果 | ResNet50、MobileNetV3等 | +| SegmentationResult | [C++/Python文档](./segmentation_result.md) | 图像分割返回结果 | PP-HumanSeg、PP-LiteSeg等 | +| DetectionResult | [C++/Python文档](./detection_result.md) | 目标检测返回结果 | PP-YOLOE、YOLOv7系列模型等 | +| FaceDetectionResult | [C++/Python文档](./face_detection_result.md) | 目标检测返回结果 | SCRFD、RetinaFace系列模型等 | +| KeyPointDetectionResult | [C++/Python文档](./keypointdetection_result.md) | 关键点检测返回结果 | PP-Tinypose系列模型等 | +| FaceRecognitionResult | [C++/Python文档](./face_recognition_result.md) | 目标检测返回结果 | ArcFace、CosFace系列模型等 | +| MattingResult | [C++/Python文档](./matting_result.md) | 目标检测返回结果 | MODNet系列模型等 | +| OCRResult | [C++/Python文档](./ocr_result.md) | 文本框检测,分类和文本识别返回结果 | OCR系列模型等 | +| MOTResult | [C++/Python文档](./mot_result.md) | 多目标跟踪返回结果 | pptracking系列模型等 | \ No newline at end of file diff --git a/docs/api/vision_results/mot_result.md b/docs/api/vision_results/mot_result.md new file mode 100644 index 000000000..0dd7cda71 --- /dev/null +++ b/docs/api/vision_results/mot_result.md @@ -0,0 +1,40 @@ +# MOTResult 多目标跟踪结果 + +MOTResult代码定义在`fastdeploy/vision/common/result.h`中,用于表明多目标跟踪中的检测出来的目标框、目标跟踪id、目标类别和目标置信度。 + +## C++ 定义 + +```c++ +fastdeploy::vision::MOTResult +``` + +```c++ +struct MOTResult{ + // left top right bottom + std::vector> boxes; + std::vector ids; + std::vector scores; + std::vector class_ids; + void Clear(); + std::string Str(); +}; +``` + +- **boxes**: 成员变量,表示单帧画面中检测出来的所有目标框坐标,`boxes.size()`表示框的个数,每个框以4个float数值依次表示xmin, ymin, xmax, ymax, 即左上角和右下角坐标 +- **ids**: 
成员变量,表示单帧画面中所有目标的id,其元素个数与`boxes.size()`一致 +- **scores**: 成员变量,表示单帧画面检测出来的所有目标置信度,其元素个数与`boxes.size()`一致 +- **class_ids**: 成员变量,表示单帧画面检测出来的所有目标类别,其元素个数与`boxes.size()`一致 +- **Clear()**: 成员函数,用于清除结构体中存储的结果 +- **Str()**: 成员函数,将结构体中的信息以字符串形式输出(用于Debug) + +## Python 定义 + +```python +fastdeploy.vision.MOTResult +``` + +- **boxes**(list of list(float)): 成员变量,表示单帧画面中检测出来的所有目标框坐标。boxes是一个list,其每个元素为一个长度为4的list, 表示为一个框,每个框以4个float数值依次表示xmin, ymin, xmax, ymax, 即左上角和右下角坐标 +- **ids**(list of int): 成员变量,表示单帧画面中所有目标的id,其元素个数与`boxes`一致 +- **scores**(list of float): 成员变量,表示单帧画面检测出来的所有目标置信度 +- **class_ids**(list of int): 成员变量,表示单帧画面检测出来的所有目标类别 + diff --git a/examples/vision/tracking/pptracking/cpp/CMakeLists.txt b/examples/vision/tracking/pptracking/cpp/CMakeLists.txt new file mode 100644 index 000000000..93540a7e8 --- /dev/null +++ b/examples/vision/tracking/pptracking/cpp/CMakeLists.txt @@ -0,0 +1,14 @@ +PROJECT(infer_demo C CXX) +CMAKE_MINIMUM_REQUIRED (VERSION 3.10) + +# 指定下载解压后的fastdeploy库路径 +option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.") + +include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake) + +# 添加FastDeploy依赖头文件 +include_directories(${FASTDEPLOY_INCS}) + +add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc) +# 添加FastDeploy库依赖 +target_link_libraries(infer_demo ${FASTDEPLOY_LIBS}) diff --git a/examples/vision/tracking/pptracking/cpp/README.md b/examples/vision/tracking/pptracking/cpp/README.md new file mode 100644 index 000000000..da0deb3f3 --- /dev/null +++ b/examples/vision/tracking/pptracking/cpp/README.md @@ -0,0 +1,79 @@ +# PP-Tracking C++部署示例 + +本目录下提供`infer.cc`快速完成PP-Tracking在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。 + +在部署前,需确认以下两个步骤 + +- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md) +- 2. 
根据开发环境,下载预编译部署库和samples代码,参考[FastDeploy预编译库](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md) + +以Linux上 PP-Tracking 推理为例,在本目录执行如下命令即可完成编译测试(如若只需在CPU上部署,可在[Fastdeploy C++预编译库](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md/CPP_prebuilt_libraries.md)下载CPU推理库) + +```bash +#下载SDK,编译模型examples代码(SDK中包含了examples代码) +wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-gpu-0.3.0.tgz +tar xvf fastdeploy-linux-x64-gpu-0.3.0.tgz +cd fastdeploy-linux-x64-gpu-0.3.0/examples/vision/tracking/pptracking/cpp/ +mkdir build && cd build +cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/../../../../../../../fastdeploy-linux-x64-gpu-0.3.0 +make -j + +# 下载PP-Tracking模型文件和测试视频 +wget https://bj.bcebos.com/paddlehub/fastdeploy/fairmot_hrnetv2_w18_dlafpn_30e_576x320.tgz +tar -xvf fairmot_hrnetv2_w18_dlafpn_30e_576x320.tgz +wget https://bj.bcebos.com/paddlehub/fastdeploy/person.mp4 +wget https://bj.bcebos.com/paddlehub/fastdeploy/person.mp4 + + +# CPU推理 +./infer_demo fairmot_hrnetv2_w18_dlafpn_30e_576x320 person.mp4 0 +# GPU推理 +./infer_demo fairmot_hrnetv2_w18_dlafpn_30e_576x320 person.mp4 1 +# GPU上TensorRT推理 +./infer_demo fairmot_hrnetv2_w18_dlafpn_30e_576x320 person.mp4 2 +``` + +以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考: +- [如何在Windows中使用FastDeploy C++ SDK](../../../../../docs/cn/faq/use_sdk_on_windows.md) + +## PP-Tracking C++接口 + +### PPTracking类 + +```c++ +fastdeploy::vision::tracking::PPTracking( + const string& model_file, + const string& params_file = "", + const string& config_file, + const RuntimeOption& runtime_option = RuntimeOption(), + const ModelFormat& model_format = ModelFormat::PADDLE) +``` + +PP-Tracking模型加载和初始化,其中model_file为导出的Paddle模型格式。 + +**参数** + +> * **model_file**(str): 模型文件路径 +> * **params_file**(str): 参数文件路径 +> * **config_file**(str): 推理部署配置文件 +> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置 +> * **model_format**(ModelFormat): 模型格式,默认为Paddle格式 + +#### Predict函数 + +> ```c++ +> 
PPTracking::Predict(cv::Mat* im, MOTResult* result) +> ``` +> +> 模型预测接口,输入图像直接输出检测结果。 +> +> **参数** +> +> > * **im**: 输入图像,注意需为HWC,BGR格式 +> > * **result**: 检测结果,包括检测框,跟踪id,各个框的置信度,对象类别id,MOTResult说明参考[视觉模型预测结果](../../../../../docs/api/vision_results/) + + +- [模型介绍](../../) +- [Python部署](../python) +- [视觉模型预测结果](../../../../../docs/api/vision_results/) +- [如何切换模型推理后端引擎](../../../../../docs/cn/faq/how_to_change_backend.md) diff --git a/examples/vision/tracking/pptracking/cpp/infer.cc b/examples/vision/tracking/pptracking/cpp/infer.cc new file mode 100644 index 000000000..709159eb4 --- /dev/null +++ b/examples/vision/tracking/pptracking/cpp/infer.cc @@ -0,0 +1,158 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fastdeploy/vision.h" + +#ifdef WIN32 +const char sep = '\\'; +#else +const char sep = '/'; +#endif + +void CpuInfer(const std::string& model_dir, const std::string& video_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "infer_cfg.yml"; + auto model = fastdeploy::vision::tracking::PPTracking( + model_file, params_file, config_file); + + if (!model.Initialized()) { + std::cerr << "Failed to initialize." 
<< std::endl; + return; + } + + fastdeploy::vision::MOTResult result; + cv::Mat frame; + int frame_id=0; + cv::VideoCapture capture(video_file); + // according to the time of prediction to calculate fps + float fps= 0.0f; + while (capture.read(frame)) { + if (frame.empty()) { + break; + } + if (!model.Predict(&frame, &result)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + // std::cout << result.Str() << std::endl; + cv::Mat out_img = fastdeploy::vision::VisMOT(frame, result, fps , frame_id); + cv::imshow("mot",out_img); + cv::waitKey(30); + frame_id++; + } + capture.release(); + cv::destroyAllWindows(); +} + +void GpuInfer(const std::string& model_dir, const std::string& video_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "infer_cfg.yml"; + + auto option = fastdeploy::RuntimeOption(); + option.UseGpu(); + auto model = fastdeploy::vision::tracking::PPTracking( + model_file, params_file, config_file, option); + + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + fastdeploy::vision::MOTResult result; + cv::Mat frame; + int frame_id=0; + cv::VideoCapture capture(video_file); + // according to the time of prediction to calculate fps + float fps= 0.0f; + while (capture.read(frame)) { + if (frame.empty()) { + break; + } + if (!model.Predict(&frame, &result)) { + std::cerr << "Failed to predict." 
<< std::endl; + return; + } + // std::cout << result.Str() << std::endl; + cv::Mat out_img = fastdeploy::vision::VisMOT(frame, result, fps , frame_id); + cv::imshow("mot",out_img); + cv::waitKey(30); + frame_id++; + } + capture.release(); + cv::destroyAllWindows(); +} + +void TrtInfer(const std::string& model_dir, const std::string& video_file) { + auto model_file = model_dir + sep + "model.pdmodel"; + auto params_file = model_dir + sep + "model.pdiparams"; + auto config_file = model_dir + sep + "infer_cfg.yml"; + + auto option = fastdeploy::RuntimeOption(); + option.UseGpu(); + option.UseTrtBackend(); + auto model = fastdeploy::vision::tracking::PPTracking( + model_file, params_file, config_file, option); + + if (!model.Initialized()) { + std::cerr << "Failed to initialize." << std::endl; + return; + } + + fastdeploy::vision::MOTResult result; + cv::Mat frame; + int frame_id=0; + cv::VideoCapture capture(video_file); + // according to the time of prediction to calculate fps + float fps= 0.0f; + while (capture.read(frame)) { + if (frame.empty()) { + break; + } + if (!model.Predict(&frame, &result)) { + std::cerr << "Failed to predict." << std::endl; + return; + } + // std::cout << result.Str() << std::endl; + cv::Mat out_img = fastdeploy::vision::VisMOT(frame, result, fps , frame_id); + cv::imshow("mot",out_img); + cv::waitKey(30); + frame_id++; + } + capture.release(); + cv::destroyAllWindows(); +} + +int main(int argc, char* argv[]) { + if (argc < 4) { + std::cout + << "Usage: infer_demo path/to/model_dir path/to/video run_option, " + "e.g ./infer_model ./pptracking_model_dir ./person.mp4 0" + << std::endl; + std::cout << "The data type of run_option is int, 0: run with cpu; 1: run " + "with gpu; 2: run with gpu and use tensorrt backend." 
+ << std::endl; + return -1; + } + + if (std::atoi(argv[3]) == 0) { + CpuInfer(argv[1], argv[2]); + } else if (std::atoi(argv[3]) == 1) { + GpuInfer(argv[1], argv[2]); + } else if (std::atoi(argv[3]) == 2) { + TrtInfer(argv[1], argv[2]); + } + return 0; +} diff --git a/examples/vision/tracking/pptracking/python/README.md b/examples/vision/tracking/pptracking/python/README.md new file mode 100644 index 000000000..d3b943759 --- /dev/null +++ b/examples/vision/tracking/pptracking/python/README.md @@ -0,0 +1,70 @@ +# PP-Tracking Python部署示例 + +在部署前,需确认以下两个步骤 + +- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md) +- 2. FastDeploy Python whl包安装,参考[FastDeploy Python安装](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md) + +本目录下提供`infer.py`快速完成PP-Tracking在CPU/GPU,以及GPU上通过TensorRT加速部署的示例。执行如下脚本即可完成 + +```bash +#下载部署示例代码 +git clone https://github.com/PaddlePaddle/FastDeploy.git +cd FastDeploy/examples/vision/tracking/pptracking/python + +# 下载PP-Tracking模型文件和测试视频 +wget https://bj.bcebos.com/paddlehub/fastdeploy/fairmot_hrnetv2_w18_dlafpn_30e_576x320.tgz +tar -xvf fairmot_hrnetv2_w18_dlafpn_30e_576x320.tgz +wget https://bj.bcebos.com/paddlehub/fastdeploy/person.mp4 +# CPU推理 +python infer.py --model fairmot_hrnetv2_w18_dlafpn_30e_576x320 --video person.mp4 --device cpu +# GPU推理 +python infer.py --model fairmot_hrnetv2_w18_dlafpn_30e_576x320 --video person.mp4 --device gpu +# GPU上使用TensorRT推理 (注意:TensorRT推理第一次运行,有序列化模型的操作,有一定耗时,需要耐心等待) +python infer.py --model fairmot_hrnetv2_w18_dlafpn_30e_576x320 --video person.mp4 --device gpu --use_trt True +``` + +## PP-Tracking Python接口 + +```python +fd.vision.tracking.PPTracking(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE) +``` + +PP-Tracking模型加载和初始化,其中model_file, params_file以及config_file为训练模型导出的Paddle inference文件,具体请参考其文档说明[模型导出](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.6/Matting) + +**参数** + +> * 
**model_file**(str): 模型文件路径 +> * **params_file**(str): 参数文件路径 +> * **config_file**(str): 推理部署配置文件 +> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置 +> * **model_format**(ModelFormat): 模型格式,默认为Paddle格式 + +### predict函数 + +> ```python +> PPTracking.predict(frame) +> ``` +> +> 模型预测接口,输入图像直接输出检测结果。 +> +> **参数** +> +> > * **frame**(np.ndarray): 输入数据,注意需为HWC,BGR格式,frame为视频帧如:_,frame=cap.read()得到 + +> **返回** +> +> > 返回`fastdeploy.vision.MOTResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/) + +### 类成员属性 +#### 预处理参数 +用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果 + + + +## 其它文档 + +- [PP-Tracking 模型介绍](..) +- [PP-Tracking C++部署](../cpp) +- [模型预测结果说明](../../../../../docs/api/vision_results/) +- [如何切换模型推理后端引擎](../../../../../docs/cn/faq/how_to_change_backend.md) diff --git a/examples/vision/tracking/pptracking/python/infer.py b/examples/vision/tracking/pptracking/python/infer.py new file mode 100644 index 000000000..39681e7e5 --- /dev/null +++ b/examples/vision/tracking/pptracking/python/infer.py @@ -0,0 +1,79 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

import os
import time


def parse_arguments(argv=None):
    """Parse the command-line options of the PP-Tracking demo.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what the script itself
            does.

    Returns:
        argparse.Namespace with the attributes ``model``, ``video``,
        ``device`` and ``use_trt``.
    """
    import argparse
    import ast

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model", required=True, help="Path of PP-Tracking model directory.")
    parser.add_argument(
        "--video", type=str, required=True, help="Path of test video file.")
    parser.add_argument(
        "--device",
        type=str,
        default='cpu',
        help="Type of inference device, support 'cpu' or 'gpu'.")
    # literal_eval turns the strings "True"/"False" from the command line into
    # real booleans (plain ``bool`` would treat "False" as truthy).
    parser.add_argument(
        "--use_trt",
        type=ast.literal_eval,
        default=False,
        help="Whether to use tensorrt.")
    return parser.parse_args(argv)


def build_option(args):
    """Build the fastdeploy RuntimeOption selected by the parsed arguments.

    Args:
        args: Namespace returned by ``parse_arguments``; ``device`` and
            ``use_trt`` are read.

    Returns:
        A ``fastdeploy.RuntimeOption`` with GPU and/or the TensorRT backend
        enabled as requested.
    """
    import fastdeploy as fd

    option = fd.RuntimeOption()

    if args.device.lower() == "gpu":
        option.use_gpu()

    if args.use_trt:
        option.use_trt_backend()
    return option


def main(argv=None):
    """Run PP-Tracking on every frame of a video and display the result.

    Args:
        argv: Optional argument list forwarded to ``parse_arguments``.

    Raises:
        SystemExit: If the video file cannot be opened.
    """
    args = parse_arguments(argv)

    # Imported after argument parsing so that ``--help`` and usage errors are
    # reported without loading the inference runtime.
    import fastdeploy as fd
    import cv2

    # Configure the runtime and load the model.
    runtime_option = build_option(args)
    model_file = os.path.join(args.model, "model.pdmodel")
    params_file = os.path.join(args.model, "model.pdiparams")
    config_file = os.path.join(args.model, "infer_cfg.yml")
    model = fd.vision.tracking.PPTracking(
        model_file, params_file, config_file, runtime_option=runtime_option)

    # Track the objects frame by frame.
    cap = cv2.VideoCapture(args.video)
    if not cap.isOpened():
        raise SystemExit("Failed to open video file: {}".format(args.video))

    frame_id = 0
    try:
        while True:
            start_time = time.time()
            frame_id += 1
            ok, frame = cap.read()
            if not ok or frame is None:
                break
            result = model.predict(frame)
            elapsed = time.time() - start_time
            # The clock may not advance between two reads on coarse timers.
            fps = 1.0 / elapsed if elapsed > 0 else 0.0
            img = fd.vision.vis_mot(frame, result, fps, frame_id)
            cv2.imshow("video", img)
            cv2.waitKey(30)
    finally:
        # Release the capture and the window even if prediction raises.
        cap.release()
        cv2.destroyAllWindows()


if __name__ == "__main__":
    main()
diff --git a/fastdeploy/vision/common/result.cc b/fastdeploy/vision/common/result.cc index c2e770083..956a6fd70 100644 --- a/fastdeploy/vision/common/result.cc +++ b/fastdeploy/vision/common/result.cc @@ -148,6 +148,26 @@ void OCRResult::Clear() { cls_labels.clear(); } +void MOTResult::Clear(){ + boxes.clear(); + ids.clear(); + scores.clear(); + class_ids.clear(); +} + +std::string MOTResult::Str(){ + std::string out; + out = "MOTResult:\nall boxes counts: "+std::to_string(boxes.size())+"\n"; + out += "[xmin\tymin\txmax\tymax\tid\tscore]\n"; + for (size_t i = 0; i < boxes.size(); ++i) { + out = out + "["+ std::to_string(boxes[i][0]) + "\t" + + std::to_string(boxes[i][1]) + "\t" + std::to_string(boxes[i][2]) + + "\t" + std::to_string(boxes[i][3]) + "\t" + + std::to_string(ids[i]) + "\t" + std::to_string(scores[i]) + "]\n"; + } + return out; +} + FaceDetectionResult::FaceDetectionResult(const FaceDetectionResult& res) { boxes.assign(res.boxes.begin(), res.boxes.end()); landmarks.assign(res.landmarks.begin(), res.landmarks.end()); diff --git a/fastdeploy/vision/common/result.h b/fastdeploy/vision/common/result.h index 2cc494240..c9a8d113a 100644 --- a/fastdeploy/vision/common/result.h +++ b/fastdeploy/vision/common/result.h @@ -26,6 +26,7 @@ enum FASTDEPLOY_DECL ResultType { DETECTION, SEGMENTATION, OCR, + MOT, FACE_DETECTION, FACE_RECOGNITION, MATTING, @@ -154,6 +155,21 @@ struct FASTDEPLOY_DECL OCRResult : public BaseResult { std::string Str(); }; +struct FASTDEPLOY_DECL MOTResult : public BaseResult { + // left top right bottom + std::vector> boxes; + std::vector ids; + std::vector scores; + std::vector class_ids; + ResultType type = ResultType::MOT; + + void Clear(); + + std::string Str(); +}; + + + /*! 
@brief Face detection result structure for all the face detection models */ struct FASTDEPLOY_DECL FaceDetectionResult : public BaseResult { @@ -268,5 +284,6 @@ struct FASTDEPLOY_DECL MattingResult : public BaseResult { std::string Str(); }; + } // namespace vision } // namespace fastdeploy diff --git a/fastdeploy/vision/facedet/contrib/scrfd.cc b/fastdeploy/vision/facedet/contrib/scrfd.cc index d87d66068..7d6974410 100644 --- a/fastdeploy/vision/facedet/contrib/scrfd.cc +++ b/fastdeploy/vision/facedet/contrib/scrfd.cc @@ -63,7 +63,7 @@ SCRFD::SCRFD(const std::string& model_file, const std::string& params_file, const RuntimeOption& custom_option, const ModelFormat& model_format) { if (model_format == ModelFormat::ONNX) { - valid_cpu_backends = {Backend::ORT}; + valid_cpu_backends = {Backend::ORT}; valid_gpu_backends = {Backend::ORT, Backend::TRT}; } else { valid_cpu_backends = {Backend::PDINFER, Backend::ORT}; diff --git a/fastdeploy/vision/tracking/pptracking/lapjv.cc b/fastdeploy/vision/tracking/pptracking/lapjv.cc new file mode 100644 index 000000000..32546b11c --- /dev/null +++ b/fastdeploy/vision/tracking/pptracking/lapjv.cc @@ -0,0 +1,413 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// The code is based on: +// https://github.com/gatagat/lap/blob/master/lap/lapjv.cpp +// Ths copyright of gatagat/lap is as follows: +// MIT License + +#include +#include +#include + +#include "fastdeploy/vision/tracking/pptracking/lapjv.h" + +namespace fastdeploy { +namespace vision { +namespace tracking { + +/** Column-reduction and reduction transfer for a dense cost matrix. + */ +int _ccrrt_dense( + const int n, float *cost[], int *free_rows, int *x, int *y, float *v) { + int n_free_rows; + bool *unique; + + for (int i = 0; i < n; i++) { + x[i] = -1; + v[i] = LARGE; + y[i] = 0; + } + for (int i = 0; i < n; i++) { + for (int j = 0; j < n; j++) { + const float c = cost[i][j]; + if (c < v[j]) { + v[j] = c; + y[j] = i; + } + } + } + NEW(unique, bool, n); + memset(unique, TRUE, n); + { + int j = n; + do { + j--; + const int i = y[j]; + if (x[i] < 0) { + x[i] = j; + } else { + unique[i] = FALSE; + y[j] = -1; + } + } while (j > 0); + } + n_free_rows = 0; + for (int i = 0; i < n; i++) { + if (x[i] < 0) { + free_rows[n_free_rows++] = i; + } else if (unique[i]) { + const int j = x[i]; + float min = LARGE; + for (int j2 = 0; j2 < n; j2++) { + if (j2 == static_cast(j)) { + continue; + } + const float c = cost[i][j2] - v[j2]; + if (c < min) { + min = c; + } + } + v[j] -= min; + } + } + FREE(unique); + return n_free_rows; +} + +/** Augmenting row reduction for a dense cost matrix. 
 *
 * For every row taken from free_rows[0, n_free_rows) the smallest (v1, at
 * column j1) and second smallest (v2, at column j2) reduced cost
 * cost[row][j] - v[j] are determined. The row is then assigned to a column;
 * if that column was owned by another row i0, i0 goes back on the free list.
 *
 * \param n            Number of rows and columns of the cost matrix.
 * \param cost         Cost matrix as n row pointers of n floats (only read).
 * \param n_free_rows  Number of valid entries in free_rows on entry.
 * \param free_rows    In: rows to process. Out: the first <return value>
 *                     entries are the rows that are still free.
 * \param x            x[row] is set to the column chosen for each row.
 * \param y            y[col] is the row owning a column; negative = none.
 * \param v            Per-column values; v[j1] may be lowered.
 * \return Number of rows left on the free list.
 */
int _carr_dense(const int n,
                float *cost[],
                const int n_free_rows,
                int *free_rows,
                int *x,
                int *y,
                float *v) {
  int current = 0;
  int new_free_rows = 0;
  int rr_cnt = 0;
  while (current < n_free_rows) {
    int i0;
    int j1, j2;
    float v1, v2, v1_new;
    bool v1_lowers;

    rr_cnt++;
    const int free_i = free_rows[current++];
    // Find the minimum (v1 at j1) and second minimum (v2 at j2) reduced cost
    // of this row. j2 stays -1 when no second candidate below LARGE exists.
    j1 = 0;
    v1 = cost[free_i][0] - v[0];
    j2 = -1;
    v2 = LARGE;
    for (int j = 1; j < n; j++) {
      const float c = cost[free_i][j] - v[j];
      if (c < v2) {
        if (c >= v1) {
          v2 = c;
          j2 = j;
        } else {
          v2 = v1;
          v1 = c;
          j2 = j1;
          j1 = j;
        }
      }
    }
    // i0 is the row currently owning the best column (negative if none).
    i0 = y[j1];
    v1_new = v[j1] - (v2 - v1);
    v1_lowers = v1_new < v[j1];
    // rr_cnt bounds how often displaced rows may be re-processed right away.
    if (rr_cnt < current * n) {
      if (v1_lowers) {
        v[j1] = v1_new;
      } else if (i0 >= 0 && j2 >= 0) {
        // The best column is taken and its value cannot be lowered: use the
        // second best column instead.
        j1 = j2;
        i0 = y[j2];
      }
      if (i0 >= 0) {
        if (v1_lowers) {
          // Re-process the displaced row immediately by stepping back.
          free_rows[--current] = i0;
        } else {
          // Defer the displaced row to the list returned to the caller.
          free_rows[new_free_rows++] = i0;
        }
      }
    } else {
      if (i0 >= 0) {
        free_rows[new_free_rows++] = i0;
      }
    }
    x[free_i] = j1;
    y[j1] = free_i;
  }
  return new_free_rows;
}

/** Find columns with minimum d[j] and put them on the SCAN list.
 *
 * Looks at cols[lo, n) and swaps every column whose d equals the minimum of
 * that range to the front, so that on return cols[lo, <return value>) holds
 * exactly the minimum-d columns.
 *
 * \param n     Number of columns.
 * \param lo    First position of cols to consider.
 * \param d     Per-column distance, indexed by column id.
 * \param cols  Permutation of column ids; reordered in place.
 * \param y     Not used by this function.
 * \return One past the last position of the minimum-d columns in cols.
 */
int _find_dense(const int n, int lo, float *d, int *cols, int *y) {
  int hi = lo + 1;
  float mind = d[cols[lo]];
  for (int k = hi; k < n; k++) {
    int j = cols[k];
    if (d[j] <= mind) {
      if (d[j] < mind) {
        // Strictly smaller: restart the list of minima at position lo.
        hi = lo;
        mind = d[j];
      }
      cols[k] = cols[hi];
      cols[hi++] = j;
    }
  }
  return hi;
}

// Scan all columns in TODO starting from arbitrary column in SCAN
// and try to decrease d of the TODO columns using the SCAN column.
//
// cols[*plo, *phi) is the SCAN list and cols[*phi, n) the TODO list.
// Returns the index of an unassigned column (y[j] < 0) as soon as one reaches
// the current minimum distance; in that case *plo and *phi are left
// untouched. Returns -1 once the SCAN list is exhausted, after writing the
// updated bounds back through plo and phi.
int _scan_dense(const int n,
                float *cost[],
                int *plo,
                int *phi,
                float *d,
                int *cols,
                int *pred,
                int *y,
                float *v) {
  int lo = *plo;
  int hi = *phi;
  float h, cred_ij;

  while (lo != hi) {
    // Take the next SCAN column and the row i assigned to it.
    int j = cols[lo++];
    const int i = y[j];
    const float mind = d[j];
    h = cost[i][j] - v[j] - mind;
    // For all columns in TODO
    for (int k = hi; k < n; k++) {
      j = cols[k];
      cred_ij = cost[i][j] - v[j] - h;
      if (cred_ij < d[j]) {
        // Reaching column j through row i is shorter: record the new
        // distance and predecessor.
        d[j] = cred_ij;
        pred[j] = i;
        if (cred_ij == mind) {
          if (y[j] < 0) {
            return j;
          }
          // Column is assigned and at minimum distance: move it to SCAN.
          cols[k] = cols[hi];
          cols[hi++] = j;
        }
      }
    }
  }
  *plo = lo;
  *phi = hi;
  return -1;
}

/** Single iteration of modified Dijkstra shortest path algorithm as explained
 * in the JV paper.
 *
 * This is a dense matrix version.
 *
 * \param n        Number of rows and columns of the cost matrix.
 * \param cost     Cost matrix as n row pointers of n floats (only read).
 * \param start_i  Free row the search starts from.
 * \param y        y[col] is the row owning a column; negative = none.
 * \param v        Per-column values; updated for the columns finalized
 *                 during the search.
 * \param pred     Out: pred[j] is the row through which column j was reached.
 * \return The closest free column index, or -1 when a scratch buffer cannot
 *         be allocated (returned from inside the NEW macro).
 */
int find_path_dense(const int n,
                    float *cost[],
                    const int start_i,
                    int *y,
                    float *v,
                    int *pred) {
  int lo = 0, hi = 0;
  int final_j = -1;
  int n_ready = 0;
  int *cols;
  float *d;

  NEW(cols, int, n);
  // NOTE(review): if this allocation fails the macro returns -1 without
  // freeing cols.
  NEW(d, float, n);

  // Every column starts at its reduced cost from start_i, reached directly.
  for (int i = 0; i < n; i++) {
    cols[i] = i;
    pred[i] = start_i;
    d[i] = cost[start_i][i] - v[i];
  }
  while (final_j == -1) {
    // No columns left on the SCAN list.
    if (lo == hi) {
      // Columns in cols[0, lo) are done; collect the next minimum-d columns.
      n_ready = lo;
      hi = _find_dense(n, lo, d, cols, y);
      for (int k = lo; k < hi; k++) {
        const int j = cols[k];
        if (y[j] < 0) {
          final_j = j;
        }
      }
    }
    if (final_j == -1) {
      final_j = _scan_dense(n, cost, &lo, &hi, d, cols, pred, y, v);
    }
  }

  {
    // Shift v of the finished columns by their distance relative to the
    // distance at position lo.
    const float mind = d[cols[lo]];
    for (int k = 0; k < n_ready; k++) {
      const int j = cols[k];
      v[j] += d[j] - mind;
    }
  }

  FREE(cols);
  FREE(d);

  return final_j;
}

/** Augment for a dense cost matrix.
+ */ +int _ca_dense(const int n, + float *cost[], + const int n_free_rows, + int *free_rows, + int *x, + int *y, + float *v) { + int *pred; + + NEW(pred, int, n); + + for (int *pfree_i = free_rows; pfree_i < free_rows + n_free_rows; pfree_i++) { + int i = -1, j; + int k = 0; + + j = find_path_dense(n, cost, *pfree_i, y, v, pred); + while (i != *pfree_i) { + i = pred[j]; + y[j] = i; + SWAP_INDICES(j, x[i]); + k++; + } + } + FREE(pred); + return 0; +} + +/** Solve dense sparse LAP. + */ +int lapjv_internal(const cv::Mat &cost, + const bool extend_cost, + const float cost_limit, + int *x, + int *y) { + int n_rows = cost.rows; + int n_cols = cost.cols; + int n; + if (n_rows == n_cols) { + n = n_rows; + } else if (!extend_cost) { + throw std::invalid_argument( + "Square cost array expected. If cost is intentionally non-square, pass " + "extend_cost=True."); + } + + // Get extend cost + if (extend_cost || cost_limit < LARGE) { + n = n_rows + n_cols; + } + cv::Mat cost_expand(n, n, CV_32F); + float expand_value; + if (cost_limit < LARGE) { + expand_value = cost_limit / 2; + } else { + double max_v; + minMaxLoc(cost, nullptr, &max_v); + expand_value = static_cast(max_v) + 1.; + } + + for (int i = 0; i < n; ++i) { + for (int j = 0; j < n; ++j) { + cost_expand.at(i, j) = expand_value; + if (i >= n_rows && j >= n_cols) { + cost_expand.at(i, j) = 0; + } else if (i < n_rows && j < n_cols) { + cost_expand.at(i, j) = cost.at(i, j); + } + } + } + + // Convert Mat to pointer array + float **cost_ptr; + NEW(cost_ptr, float *, n); + for (int i = 0; i < n; ++i) { + NEW(cost_ptr[i], float, n); + } + for (int i = 0; i < n; ++i) { + for (int j = 0; j < n; ++j) { + cost_ptr[i][j] = cost_expand.at(i, j); + } + } + + int ret; + int *free_rows; + float *v; + int *x_c; + int *y_c; + + NEW(free_rows, int, n); + NEW(v, float, n); + NEW(x_c, int, n); + NEW(y_c, int, n); + + ret = _ccrrt_dense(n, cost_ptr, free_rows, x_c, y_c, v); + int i = 0; + while (ret > 0 && i < 2) { + ret = _carr_dense(n, 
cost_ptr, ret, free_rows, x_c, y_c, v); + i++; + } + if (ret > 0) { + ret = _ca_dense(n, cost_ptr, ret, free_rows, x_c, y_c, v); + } + FREE(v); + FREE(free_rows); + for (int i = 0; i < n; ++i) { + FREE(cost_ptr[i]); + } + FREE(cost_ptr); + if (ret != 0) { + if (ret == -1) { + throw "Out of memory."; + } + throw "Unknown error (lapjv_internal)"; + } + // Get output of x, y, opt + for (int i = 0; i < n; ++i) { + if (i < n_rows) { + x[i] = x_c[i]; + if (x[i] >= n_cols) { + x[i] = -1; + } + } + if (i < n_cols) { + y[i] = y_c[i]; + if (y[i] >= n_rows) { + y[i] = -1; + } + } + } + + FREE(x_c); + FREE(y_c); + return ret; +} + +} // namespace tracking +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/tracking/pptracking/lapjv.h b/fastdeploy/vision/tracking/pptracking/lapjv.h new file mode 100644 index 000000000..b8e122204 --- /dev/null +++ b/fastdeploy/vision/tracking/pptracking/lapjv.h @@ -0,0 +1,66 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// The code is based on: +// https://github.com/gatagat/lap/blob/master/lap/lapjv.h +// Ths copyright of gatagat/lap is as follows: +// MIT License + +#pragma once +#define LARGE 1000000 + +#if !defined TRUE +#define TRUE 1 +#endif +#if !defined FALSE +#define FALSE 0 +#endif + +#define NEW(x, t, n) \ + if ((x = reinterpret_cast(malloc(sizeof(t) * (n)))) == 0) { \ + return -1; \ + } +#define FREE(x) \ + if (x != 0) { \ + free(x); \ + x = 0; \ + } +#define SWAP_INDICES(a, b) \ + { \ + int_t _temp_index = a; \ + a = b; \ + b = _temp_index; \ + } +#include + +namespace fastdeploy { +namespace vision { +namespace tracking { + +typedef signed int int_t; +typedef unsigned int uint_t; +typedef double cost_t; +typedef char boolean; +typedef enum fp_t { FP_1 = 1, FP_2 = 2, FP_DYNAMIC = 3 } fp_t; + +int lapjv_internal(const cv::Mat &cost, + const bool extend_cost, + const float cost_limit, + int *x, + int *y); + +} // namespace tracking +} // namespace vision +} // namespace fastdeploy + diff --git a/fastdeploy/vision/tracking/pptracking/letter_box.cc b/fastdeploy/vision/tracking/pptracking/letter_box.cc new file mode 100644 index 000000000..bc8e0cbcb --- /dev/null +++ b/fastdeploy/vision/tracking/pptracking/letter_box.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "fastdeploy/vision/tracking/pptracking/letter_box.h" + +namespace fastdeploy { +namespace vision { +namespace tracking { + +LetterBoxResize::LetterBoxResize(const std::vector& target_size, const std::vector& color){ + target_size_=target_size; + color_=color; +} +bool LetterBoxResize::ImplByOpenCV(Mat* mat){ + if (mat->Channels() != color_.size()) { + FDERROR << "Pad: Require input channels equals to size of padding value, " + "but now channels = " + << mat->Channels() + << ", the size of padding values = " << color_.size() << "." + << std::endl; + return false; + } + // generate scale_factor + int origin_w = mat->Width(); + int origin_h = mat->Height(); + int target_h = target_size_[0]; + int target_w = target_size_[1]; + float ratio_h = static_cast(target_h) / static_cast(origin_h); + float ratio_w = static_cast(target_w) / static_cast(origin_w); + float resize_scale = std::min(ratio_h, ratio_w); + + int new_shape_w = std::round(mat->Width() * resize_scale); + int new_shape_h = std::round(mat->Height() * resize_scale); + float padw = (target_size_[1] - new_shape_w) / 2.; + float padh = (target_size_[0] - new_shape_h) / 2.; + int top = std::round(padh - 0.1); + int bottom = std::round(padh + 0.1); + int left = std::round(padw - 0.1); + int right = std::round(padw + 0.1); + + Resize::Run(mat,new_shape_w,new_shape_h); + Pad::Run(mat,top,bottom,left,right,color_); + return true; +} + +} // namespace tracking +} // namespace vision +} // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/vision/tracking/pptracking/letter_box.h b/fastdeploy/vision/tracking/pptracking/letter_box.h new file mode 100644 index 000000000..17f9f2833 --- /dev/null +++ b/fastdeploy/vision/tracking/pptracking/letter_box.h @@ -0,0 +1,36 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/fastdeploy_model.h" +namespace fastdeploy { +namespace vision { +namespace tracking { +class LetterBoxResize: public Processor{ +public: + LetterBoxResize(const std::vector& target_size, const std::vector& color); + bool ImplByOpenCV(Mat* mat) override; + std::string Name() override { return "LetterBoxResize"; } +private: + std::vector target_size_; + std::vector color_; +}; + +} // namespace tracking +} // namespace vision +} // namespace fastdeploy + + diff --git a/fastdeploy/vision/tracking/pptracking/model.cc b/fastdeploy/vision/tracking/pptracking/model.cc new file mode 100644 index 000000000..b4915f8a3 --- /dev/null +++ b/fastdeploy/vision/tracking/pptracking/model.cc @@ -0,0 +1,331 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "fastdeploy/vision/tracking/pptracking/model.h" +#include "yaml-cpp/yaml.h" +#include "paddle2onnx/converter.h" + +namespace fastdeploy { +namespace vision { +namespace tracking { + +PPTracking::PPTracking(const std::string& model_file, + const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option, + const ModelFormat& model_format){ + config_file_=config_file; + valid_cpu_backends = {Backend::PDINFER, Backend::ORT}; + valid_gpu_backends = {Backend::PDINFER, Backend::ORT}; + + runtime_option = custom_option; + runtime_option.model_format = model_format; + runtime_option.model_file = model_file; + runtime_option.params_file = params_file; + + initialized = Initialize(); +} + +bool PPTracking::BuildPreprocessPipelineFromConfig(){ + processors_.clear(); + YAML::Node cfg; + try { + cfg = YAML::LoadFile(config_file_); + } catch (YAML::BadFile& e) { + FDERROR << "Failed to load yaml file " << config_file_ + << ", maybe you should check this file." << std::endl; + return false; + } + + // Get draw_threshold for visualization + if (cfg["draw_threshold"].IsDefined()) { + draw_threshold_ = cfg["draw_threshold"].as(); + } else { + FDERROR << "Please set draw_threshold." << std::endl; + return false; + } + // Get config for tracker + if (cfg["tracker"].IsDefined()) { + if (cfg["tracker"]["conf_thres"].IsDefined()) { + conf_thresh_ = cfg["tracker"]["conf_thres"].as(); + } + else { + std::cerr << "Please set conf_thres in tracker." 
<< std::endl; + return false; + } + if (cfg["tracker"]["min_box_area"].IsDefined()) { + min_box_area_ = cfg["tracker"]["min_box_area"].as(); + } + if (cfg["tracker"]["tracked_thresh"].IsDefined()) { + tracked_thresh_ = cfg["tracker"]["tracked_thresh"].as(); + } + } + + processors_.push_back(std::make_shared()); + for (const auto& op : cfg["Preprocess"]) { + std::string op_name = op["type"].as(); + if (op_name == "Resize") { + bool keep_ratio = op["keep_ratio"].as(); + auto target_size = op["target_size"].as>(); + int interp = op["interp"].as(); + FDASSERT(target_size.size() == 2, + "Require size of target_size be 2, but now it's %lu.", + target_size.size()); + if (!keep_ratio) { + int width = target_size[1]; + int height = target_size[0]; + processors_.push_back( + std::make_shared(width, height, -1.0, -1.0, interp, false)); + } else { + int min_target_size = std::min(target_size[0], target_size[1]); + int max_target_size = std::max(target_size[0], target_size[1]); + std::vector max_size; + if (max_target_size > 0) { + max_size.push_back(max_target_size); + max_size.push_back(max_target_size); + } + processors_.push_back(std::make_shared( + min_target_size, interp, true, max_size)); + } + + } + else if(op_name == "LetterBoxResize"){ + auto target_size = op["target_size"].as>(); + FDASSERT(target_size.size() == 2,"Require size of target_size be 2, but now it's %lu.", + target_size.size()); + std::vector color{127.0f,127.0f,127.0f}; + if (op["fill_value"].IsDefined()){ + color =op["fill_value"].as>(); + } + processors_.push_back(std::make_shared(target_size, color)); + } + else if (op_name == "NormalizeImage") { + auto mean = op["mean"].as>(); + auto std = op["std"].as>(); + bool is_scale = true; + if (op["is_scale"]) { + is_scale = op["is_scale"].as(); + } + std::string norm_type = "mean_std"; + if (op["norm_type"]) { + norm_type = op["norm_type"].as(); + } + if (norm_type != "mean_std") { + std::fill(mean.begin(), mean.end(), 0.0); + std::fill(std.begin(), 
std.end(), 1.0); + } + processors_.push_back(std::make_shared(mean, std, is_scale)); + } + else if (op_name == "Permute") { + // Do nothing, do permute as the last operation + continue; + // processors_.push_back(std::make_shared()); + } else if (op_name == "Pad") { + auto size = op["size"].as>(); + auto value = op["fill_value"].as>(); + processors_.push_back(std::make_shared("float")); + processors_.push_back( + std::make_shared(size[1], size[0], value)); + } else if (op_name == "PadStride") { + auto stride = op["stride"].as(); + processors_.push_back( + std::make_shared(stride, std::vector(3, 0))); + } else { + FDERROR << "Unexcepted preprocess operator: " << op_name << "." + << std::endl; + return false; + } + } + processors_.push_back(std::make_shared()); + return true; +} + +void PPTracking::GetNmsInfo() { + if (runtime_option.model_format == ModelFormat::PADDLE) { + std::string contents; + if (!ReadBinaryFromFile(runtime_option.model_file, &contents)) { + return; + } + auto reader = paddle2onnx::PaddleReader(contents.c_str(), contents.size()); + if (reader.has_nms) { + has_nms_ = true; + background_label = reader.nms_params.background_label; + keep_top_k = reader.nms_params.keep_top_k; + nms_eta = reader.nms_params.nms_eta; + nms_threshold = reader.nms_params.nms_threshold; + score_threshold = reader.nms_params.score_threshold; + nms_top_k = reader.nms_params.nms_top_k; + normalized = reader.nms_params.normalized; + } + } +} + +bool PPTracking::Initialize() { + // remove multiclass_nms3 now + // this is a trick operation for ppyoloe while inference on trt + GetNmsInfo(); + runtime_option.remove_multiclass_nms_ = true; + runtime_option.custom_op_info_["multiclass_nms3"] = "MultiClassNMS"; + if (!BuildPreprocessPipelineFromConfig()) { + FDERROR << "Failed to build preprocess pipeline from configuration file." + << std::endl; + return false; + } + if (!InitRuntime()) { + FDERROR << "Failed to initialize fastdeploy backend." 
<< std::endl; + return false; + } + + if (has_nms_ && runtime_option.backend == Backend::TRT) { + FDINFO << "Detected operator multiclass_nms3 in your model, will replace " + "it with fastdeploy::backend::MultiClassNMS(background_label=" + << background_label << ", keep_top_k=" << keep_top_k + << ", nms_eta=" << nms_eta << ", nms_threshold=" << nms_threshold + << ", score_threshold=" << score_threshold + << ", nms_top_k=" << nms_top_k << ", normalized=" << normalized + << ")." << std::endl; + has_nms_ = false; + } + + // create JDETracker instance + std::unique_ptr jdeTracker(new JDETracker); + jdeTracker_ = std::move(jdeTracker); + + return true; +} + +bool PPTracking::Predict(cv::Mat *img, MOTResult *result) { + Mat mat(*img); + std::vector input_tensors; + + if (!Preprocess(&mat, &input_tensors)) { + FDERROR << "Failed to preprocess input image." << std::endl; + return false; + } + std::vector output_tensors; + if (!Infer(input_tensors, &output_tensors)) { + FDERROR << "Failed to inference." << std::endl; + return false; + } + + if (!Postprocess(output_tensors, result)) { + FDERROR << "Failed to post process." << std::endl; + return false; + } + return true; +} + + +bool PPTracking::Preprocess(Mat* mat, std::vector* outputs) { + + int origin_w = mat->Width(); + int origin_h = mat->Height(); + + for (size_t i = 0; i < processors_.size(); ++i) { + if (!(*(processors_[i].get()))(mat)) { + FDERROR << "Failed to process image data in " << processors_[i]->Name() + << "." 
<< std::endl; + return false; + } + } + +// LetterBoxResize(mat); +// Normalize::Run(mat,mean_,scale_,is_scale_); +// HWC2CHW::Run(mat); + Cast::Run(mat, "float"); + + outputs->resize(3); + // image_shape + (*outputs)[0].Allocate({1, 2}, FDDataType::FP32, InputInfoOfRuntime(0).name); + float* shape = static_cast((*outputs)[0].MutableData()); + shape[0] = mat->Height(); + shape[1] = mat->Width(); + // image + (*outputs)[1].name = InputInfoOfRuntime(1).name; + mat->ShareWithTensor(&((*outputs)[1])); + (*outputs)[1].ExpandDim(0); + // scale + (*outputs)[2].Allocate({1, 2}, FDDataType::FP32, InputInfoOfRuntime(2).name); + float* scale = static_cast((*outputs)[2].MutableData()); + scale[0] = mat->Height() * 1.0 / origin_h; + scale[1] = mat->Width() * 1.0 / origin_w; + return true; +} + + +void FilterDets(const float conf_thresh,const cv::Mat& dets,std::vector* index) { + for (int i = 0; i < dets.rows; ++i) { + float score = *dets.ptr(i, 4); + if (score > conf_thresh) { + index->push_back(i); + } + } +} + +bool PPTracking::Postprocess(std::vector& infer_result, MOTResult *result){ + auto bbox_shape = infer_result[0].shape; + auto bbox_data = static_cast(infer_result[0].Data()); + + auto emb_shape = infer_result[1].shape; + auto emb_data = static_cast(infer_result[1].Data()); + + cv::Mat dets(bbox_shape[0], 6, CV_32FC1, bbox_data); + cv::Mat emb(bbox_shape[0], emb_shape[1], CV_32FC1, emb_data); + + + result->Clear(); + std::vector tracks; + std::vector valid; + FilterDets(conf_thresh_, dets, &valid); + cv::Mat new_dets, new_emb; + for (int i = 0; i < valid.size(); ++i) { + new_dets.push_back(dets.row(valid[i])); + new_emb.push_back(emb.row(valid[i])); + } + jdeTracker_->update(new_dets, new_emb, &tracks); + if (tracks.size() == 0) { + std::array box={int(*dets.ptr(0, 0)), + int(*dets.ptr(0, 1)), + int(*dets.ptr(0, 2)), + int(*dets.ptr(0, 3))}; + result->boxes.push_back(box); + result->ids.push_back(1); + result->scores.push_back(*dets.ptr(0, 4)); + + } else { + 
std::vector::iterator titer; + for (titer = tracks.begin(); titer != tracks.end(); ++titer) { + if (titer->score < tracked_thresh_) { + continue; + } else { + float w = titer->ltrb[2] - titer->ltrb[0]; + float h = titer->ltrb[3] - titer->ltrb[1]; + bool vertical = w / h > 1.6; + float area = w * h; + if (area > min_box_area_ && !vertical) { + std::array box = { + int(titer->ltrb[0]), int(titer->ltrb[1]), int(titer->ltrb[2]), int(titer->ltrb[3])}; + result->boxes.push_back(box); + result->ids.push_back(titer->id); + result->scores.push_back(titer->score); + } + } + } + } + return true; +} + +} // namespace tracking +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/tracking/pptracking/model.h b/fastdeploy/vision/tracking/pptracking/model.h new file mode 100644 index 000000000..040ed383f --- /dev/null +++ b/fastdeploy/vision/tracking/pptracking/model.h @@ -0,0 +1,90 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "fastdeploy/vision/common/processors/transform.h" +#include "fastdeploy/fastdeploy_model.h" +#include "fastdeploy/vision/common/result.h" +#include "fastdeploy/vision/tracking/pptracking/tracker.h" +#include "fastdeploy/vision/tracking/pptracking/letter_box.h" + +namespace fastdeploy { +namespace vision { +namespace tracking { + +class FASTDEPLOY_DECL PPTracking: public FastDeployModel { + +public: + /** \brief Set path of model file and configuration file, and the configuration of runtime + * + * \param[in] model_file Path of model file, e.g pptracking/model.pdmodel + * \param[in] params_file Path of parameter file, e.g pptracking/model.pdiparams, if the model format is ONNX, this parameter will be ignored + * \param[in] config_file Path of configuration file for deployment, e.g pptracking/infer_cfg.yml + * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends` + * \param[in] model_format Model format of the loaded model, default is Paddle format + */ + PPTracking(const std::string& model_file, + const std::string& params_file, + const std::string& config_file, + const RuntimeOption& custom_option = RuntimeOption(), + const ModelFormat& model_format = ModelFormat::PADDLE); + + /// Get model's name + std::string ModelName() const override { return "pptracking"; } + + /** \brief Predict the detection result for an input image(consecutive) + * + * \param[in] im The input image data which is consecutive frame, comes from imread() or videoCapture.read() + * \param[in] result The output tracking result will be writen to this structure + * \return true if the prediction successed, otherwise false + */ + virtual bool Predict(cv::Mat* img, MOTResult* result); + + +private: + + bool BuildPreprocessPipelineFromConfig(); + bool Initialize(); + void GetNmsInfo(); + + bool Preprocess(Mat* img, std::vector* outputs); + + bool Postprocess(std::vector& infer_result, 
MOTResult *result); + + std::vector> processors_; + std::string config_file_; + float draw_threshold_; + float conf_thresh_; + float tracked_thresh_; + float min_box_area_; + bool is_scale_ = true; + std::unique_ptr jdeTracker_; + + // configuration for nms + int64_t background_label = -1; + int64_t keep_top_k = 300; + float nms_eta = 1.0; + float nms_threshold = 0.7; + float score_threshold = 0.01; + int64_t nms_top_k = 10000; + bool normalized = true; + bool has_nms_ = true; + +}; + +} // namespace tracking +} // namespace vision +} // namespace fastdeploy + diff --git a/fastdeploy/vision/tracking/pptracking/pptracking_pybind.cc b/fastdeploy/vision/tracking/pptracking/pptracking_pybind.cc new file mode 100644 index 000000000..d56437ad5 --- /dev/null +++ b/fastdeploy/vision/tracking/pptracking/pptracking_pybind.cc @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#include "fastdeploy/pybind/main.h" + +namespace fastdeploy { +void BindPPTracking(pybind11::module &m) { + pybind11::class_( + m, "PPTracking") + .def(pybind11::init()) + .def("predict", + [](vision::tracking::PPTracking &self, + pybind11::array &data) { + auto mat = PyArrayToCvMat(data); + vision::MOTResult *res = new vision::MOTResult(); + self.Predict(&mat, res); + return res; + }); +} +} // namespace fastdeploy diff --git a/fastdeploy/vision/tracking/pptracking/tracker.cc b/fastdeploy/vision/tracking/pptracking/tracker.cc new file mode 100644 index 000000000..0026d889c --- /dev/null +++ b/fastdeploy/vision/tracking/pptracking/tracker.cc @@ -0,0 +1,305 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// The code is based on: +// https://github.com/CnybTseng/JDE/blob/master/platforms/common/jdetracker.cpp +// Ths copyright of CnybTseng/JDE is as follows: +// MIT License + +#include +#include +#include +#include + +#include "fastdeploy/vision/tracking/pptracking/lapjv.h" +#include "fastdeploy/vision/tracking/pptracking/tracker.h" + +#define mat2vec4f(m) \ + cv::Vec4f(*m.ptr(0, 0), \ + *m.ptr(0, 1), \ + *m.ptr(0, 2), \ + *m.ptr(0, 3)) + +namespace fastdeploy { +namespace vision { +namespace tracking { + +static std::map chi2inv95 = {{1, 3.841459f}, + {2, 5.991465f}, + {3, 7.814728f}, + {4, 9.487729f}, + {5, 11.070498f}, + {6, 12.591587f}, + {7, 14.067140f}, + {8, 15.507313f}, + {9, 16.918978f}}; + + +JDETracker::JDETracker() + : timestamp(0), max_lost_time(30), lambda(0.98f), det_thresh(0.3f) {} + +bool JDETracker::update(const cv::Mat &dets, + const cv::Mat &emb, + std::vector *tracks) { + ++timestamp; + TrajectoryPool candidates(dets.rows); + for (int i = 0; i < dets.rows; ++i) { + float score = *dets.ptr(i, 1); + const cv::Mat <rb_ = dets(cv::Rect(2, i, 4, 1)); + cv::Vec4f ltrb = mat2vec4f(ltrb_); + const cv::Mat &embedding = emb(cv::Rect(0, i, emb.cols, 1)); + candidates[i] = Trajectory(ltrb, score, embedding); + } + + TrajectoryPtrPool tracked_trajectories; + TrajectoryPtrPool unconfirmed_trajectories; + for (size_t i = 0; i < this->tracked_trajectories.size(); ++i) { + if (this->tracked_trajectories[i].is_activated) + tracked_trajectories.push_back(&this->tracked_trajectories[i]); + else + unconfirmed_trajectories.push_back(&this->tracked_trajectories[i]); + } + + TrajectoryPtrPool trajectory_pool = + tracked_trajectories + &(this->lost_trajectories); + + for (size_t i = 0; i < trajectory_pool.size(); ++i) + trajectory_pool[i]->predict(); + + Match matches; + std::vector mismatch_row; + std::vector mismatch_col; + + cv::Mat cost = motion_distance(trajectory_pool, candidates); + linear_assignment(cost, 0.7f, &matches, &mismatch_row, &mismatch_col); + + 
MatchIterator miter; + TrajectoryPtrPool activated_trajectories; + TrajectoryPtrPool retrieved_trajectories; + + for (miter = matches.begin(); miter != matches.end(); miter++) { + Trajectory *pt = trajectory_pool[miter->first]; + Trajectory &ct = candidates[miter->second]; + if (pt->state == Tracked) { + pt->update(&ct, timestamp); + activated_trajectories.push_back(pt); + } else { + pt->reactivate(&ct, count,timestamp); + retrieved_trajectories.push_back(pt); + } + } + + TrajectoryPtrPool next_candidates(mismatch_col.size()); + for (size_t i = 0; i < mismatch_col.size(); ++i) + next_candidates[i] = &candidates[mismatch_col[i]]; + + TrajectoryPtrPool next_trajectory_pool; + for (size_t i = 0; i < mismatch_row.size(); ++i) { + int j = mismatch_row[i]; + if (trajectory_pool[j]->state == Tracked) + next_trajectory_pool.push_back(trajectory_pool[j]); + } + + cost = iou_distance(next_trajectory_pool, next_candidates); + linear_assignment(cost, 0.5f, &matches, &mismatch_row, &mismatch_col); + + for (miter = matches.begin(); miter != matches.end(); miter++) { + Trajectory *pt = next_trajectory_pool[miter->first]; + Trajectory *ct = next_candidates[miter->second]; + if (pt->state == Tracked) { + pt->update(ct, timestamp); + activated_trajectories.push_back(pt); + } else { + pt->reactivate(ct,count, timestamp); + retrieved_trajectories.push_back(pt); + } + } + + TrajectoryPtrPool lost_trajectories; + for (size_t i = 0; i < mismatch_row.size(); ++i) { + Trajectory *pt = next_trajectory_pool[mismatch_row[i]]; + if (pt->state != Lost) { + pt->mark_lost(); + lost_trajectories.push_back(pt); + } + } + + TrajectoryPtrPool nnext_candidates(mismatch_col.size()); + for (size_t i = 0; i < mismatch_col.size(); ++i) + nnext_candidates[i] = next_candidates[mismatch_col[i]]; + cost = iou_distance(unconfirmed_trajectories, nnext_candidates); + linear_assignment(cost, 0.7f, &matches, &mismatch_row, &mismatch_col); + + for (miter = matches.begin(); miter != matches.end(); miter++) { + 
unconfirmed_trajectories[miter->first]->update( + nnext_candidates[miter->second], timestamp); + activated_trajectories.push_back(unconfirmed_trajectories[miter->first]); + } + + TrajectoryPtrPool removed_trajectories; + + for (size_t i = 0; i < mismatch_row.size(); ++i) { + unconfirmed_trajectories[mismatch_row[i]]->mark_removed(); + removed_trajectories.push_back(unconfirmed_trajectories[mismatch_row[i]]); + } + + for (size_t i = 0; i < mismatch_col.size(); ++i) { + if (nnext_candidates[mismatch_col[i]]->score < det_thresh) continue; + nnext_candidates[mismatch_col[i]]->activate(count, timestamp); + activated_trajectories.push_back(nnext_candidates[mismatch_col[i]]); + } + + for (size_t i = 0; i < this->lost_trajectories.size(); ++i) { + Trajectory < = this->lost_trajectories[i]; + if (timestamp - lt.timestamp > max_lost_time) { + lt.mark_removed(); + removed_trajectories.push_back(<); + } + } + + TrajectoryPoolIterator piter; + for (piter = this->tracked_trajectories.begin(); + piter != this->tracked_trajectories.end();) { + if (piter->state != Tracked) + piter = this->tracked_trajectories.erase(piter); + else + ++piter; + } + + this->tracked_trajectories += activated_trajectories; + this->tracked_trajectories += retrieved_trajectories; + + this->lost_trajectories -= this->tracked_trajectories; + this->lost_trajectories += lost_trajectories; + this->lost_trajectories -= this->removed_trajectories; + this->removed_trajectories += removed_trajectories; + remove_duplicate_trajectory(&this->tracked_trajectories, + &this->lost_trajectories); + + tracks->clear(); + for (size_t i = 0; i < this->tracked_trajectories.size(); ++i) { + if (this->tracked_trajectories[i].is_activated) { + Track track = {this->tracked_trajectories[i].id, + this->tracked_trajectories[i].score, + this->tracked_trajectories[i].ltrb}; + tracks->push_back(track); + } + } + return 0; +} + +cv::Mat JDETracker::motion_distance(const TrajectoryPtrPool &a, + const TrajectoryPool &b) { + if (0 == 
a.size() || 0 == b.size()) + return cv::Mat(a.size(), b.size(), CV_32F); + + cv::Mat edists = embedding_distance(a, b); + cv::Mat mdists = mahalanobis_distance(a, b); + cv::Mat fdists = lambda * edists + (1 - lambda) * mdists; + + const float gate_thresh = chi2inv95[4]; + for (int i = 0; i < fdists.rows; ++i) { + for (int j = 0; j < fdists.cols; ++j) { + if (*mdists.ptr(i, j) > gate_thresh) + *fdists.ptr(i, j) = FLT_MAX; + } + } + + return fdists; +} + +void JDETracker::linear_assignment(const cv::Mat &cost, + float cost_limit, + Match *matches, + std::vector *mismatch_row, + std::vector *mismatch_col) { + matches->clear(); + mismatch_row->clear(); + mismatch_col->clear(); + if (cost.empty()) { + for (int i = 0; i < cost.rows; ++i) mismatch_row->push_back(i); + for (int i = 0; i < cost.cols; ++i) mismatch_col->push_back(i); + return; + } + + float opt = 0; + cv::Mat x(cost.rows, 1, CV_32S); + cv::Mat y(cost.cols, 1, CV_32S); + + lapjv_internal(cost, + true, + cost_limit, + reinterpret_cast(x.data), + reinterpret_cast(y.data)); + + for (int i = 0; i < x.rows; ++i) { + int j = *x.ptr(i); + if (j >= 0) + matches->insert({i, j}); + else + mismatch_row->push_back(i); + } + + for (int i = 0; i < y.rows; ++i) { + int j = *y.ptr(i); + if (j < 0) mismatch_col->push_back(i); + } + + return; +} + +void JDETracker::remove_duplicate_trajectory(TrajectoryPool *a, + TrajectoryPool *b, + float iou_thresh) { + if (a->size() == 0 || b->size() == 0) return; + + cv::Mat dist = iou_distance(*a, *b); + cv::Mat mask = dist < iou_thresh; + std::vector idx; + cv::findNonZero(mask, idx); + + std::vector da; + std::vector db; + for (size_t i = 0; i < idx.size(); ++i) { + int ta = (*a)[idx[i].y].timestamp - (*a)[idx[i].y].starttime; + int tb = (*b)[idx[i].x].timestamp - (*b)[idx[i].x].starttime; + if (ta > tb) + db.push_back(idx[i].x); + else + da.push_back(idx[i].y); + } + + int id = 0; + TrajectoryPoolIterator piter; + for (piter = a->begin(); piter != a->end();) { + std::vector::iterator 
iter = find(da.begin(), da.end(), id++); + if (iter != da.end()) + piter = a->erase(piter); + else + ++piter; + } + + id = 0; + for (piter = b->begin(); piter != b->end();) { + std::vector::iterator iter = find(db.begin(), db.end(), id++); + if (iter != db.end()) + piter = b->erase(piter); + else + ++piter; + } +} + +} // namespace tracking +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/tracking/pptracking/tracker.h b/fastdeploy/vision/tracking/pptracking/tracker.h new file mode 100644 index 000000000..92344450e --- /dev/null +++ b/fastdeploy/vision/tracking/pptracking/tracker.h @@ -0,0 +1,78 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// The code is based on: +// https://github.com/CnybTseng/JDE/blob/master/platforms/common/jdetracker.h +// Ths copyright of CnybTseng/JDE is as follows: +// MIT License + +#pragma once + +#include +#include + +#include +#include +#include +#include "fastdeploy/fastdeploy_model.h" +#include "fastdeploy/vision/tracking/pptracking/trajectory.h" + +namespace fastdeploy { +namespace vision { +namespace tracking { + +typedef std::map Match; +typedef std::map::iterator MatchIterator; + +struct Track { + int id; + float score; + cv::Vec4f ltrb; +}; + +class FASTDEPLOY_DECL JDETracker { + public: + + JDETracker(); + + virtual bool update(const cv::Mat &dets, + const cv::Mat &emb, + std::vector *tracks); + virtual ~JDETracker() {} + private: + + cv::Mat motion_distance(const TrajectoryPtrPool &a, const TrajectoryPool &b); + void linear_assignment(const cv::Mat &cost, + float cost_limit, + Match *matches, + std::vector *mismatch_row, + std::vector *mismatch_col); + void remove_duplicate_trajectory(TrajectoryPool *a, + TrajectoryPool *b, + float iou_thresh = 0.15f); + + private: + int timestamp; + TrajectoryPool tracked_trajectories; + TrajectoryPool lost_trajectories; + TrajectoryPool removed_trajectories; + int max_lost_time; + float lambda; + float det_thresh; + int count = 0; +}; + +} // namespace tracking +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/tracking/pptracking/trajectory.cc b/fastdeploy/vision/tracking/pptracking/trajectory.cc new file mode 100644 index 000000000..fd54b4c2e --- /dev/null +++ b/fastdeploy/vision/tracking/pptracking/trajectory.cc @@ -0,0 +1,519 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// The code is based on: +// https://github.com/CnybTseng/JDE/blob/master/platforms/common/trajectory.cpp +// Ths copyright of CnybTseng/JDE is as follows: +// MIT License + +#include "fastdeploy/vision/tracking/pptracking/trajectory.h" +#include + +namespace fastdeploy { +namespace vision { +namespace tracking { + +void TKalmanFilter::init(const cv::Mat &measurement) { + measurement.copyTo(statePost(cv::Rect(0, 0, 1, 4))); + statePost(cv::Rect(0, 4, 1, 4)).setTo(0); + statePost.copyTo(statePre); + + float varpos = 2 * std_weight_position * (*measurement.ptr(3)); + varpos *= varpos; + float varvel = 10 * std_weight_velocity * (*measurement.ptr(3)); + varvel *= varvel; + + errorCovPost.setTo(0); + *errorCovPost.ptr(0, 0) = varpos; + *errorCovPost.ptr(1, 1) = varpos; + *errorCovPost.ptr(2, 2) = 1e-4f; + *errorCovPost.ptr(3, 3) = varpos; + *errorCovPost.ptr(4, 4) = varvel; + *errorCovPost.ptr(5, 5) = varvel; + *errorCovPost.ptr(6, 6) = 1e-10f; + *errorCovPost.ptr(7, 7) = varvel; + errorCovPost.copyTo(errorCovPre); +} + +const cv::Mat &TKalmanFilter::predict() { + float varpos = std_weight_position * (*statePre.ptr(3)); + varpos *= varpos; + float varvel = std_weight_velocity * (*statePre.ptr(3)); + varvel *= varvel; + + processNoiseCov.setTo(0); + *processNoiseCov.ptr(0, 0) = varpos; + *processNoiseCov.ptr(1, 1) = varpos; + *processNoiseCov.ptr(2, 2) = 1e-4f; + *processNoiseCov.ptr(3, 3) = varpos; + *processNoiseCov.ptr(4, 4) = varvel; + *processNoiseCov.ptr(5, 5) = varvel; + *processNoiseCov.ptr(6, 6) = 1e-10f; + *processNoiseCov.ptr(7, 7) = 
varvel; + + return cv::KalmanFilter::predict(); +} + +const cv::Mat &TKalmanFilter::correct(const cv::Mat &measurement) { + float varpos = std_weight_position * (*measurement.ptr(3)); + varpos *= varpos; + + measurementNoiseCov.setTo(0); + *measurementNoiseCov.ptr(0, 0) = varpos; + *measurementNoiseCov.ptr(1, 1) = varpos; + *measurementNoiseCov.ptr(2, 2) = 1e-2f; + *measurementNoiseCov.ptr(3, 3) = varpos; + + return cv::KalmanFilter::correct(measurement); +} + +void TKalmanFilter::project(cv::Mat *mean, cv::Mat *covariance) const { + float varpos = std_weight_position * (*statePost.ptr(3)); + varpos *= varpos; + + cv::Mat measurementNoiseCov_ = cv::Mat::eye(4, 4, CV_32F); + *measurementNoiseCov_.ptr(0, 0) = varpos; + *measurementNoiseCov_.ptr(1, 1) = varpos; + *measurementNoiseCov_.ptr(2, 2) = 1e-2f; + *measurementNoiseCov_.ptr(3, 3) = varpos; + + *mean = measurementMatrix * statePost; + cv::Mat temp = measurementMatrix * errorCovPost; + gemm(temp, + measurementMatrix, + 1, + measurementNoiseCov_, + 1, + *covariance, + cv::GEMM_2_T); +} + +const cv::Mat &Trajectory::predict(void) { + if (state != Tracked) *cv::KalmanFilter::statePost.ptr(7) = 0; + return TKalmanFilter::predict(); +} + +void Trajectory::update(Trajectory *traj, + int timestamp_, + bool update_embedding_) { + timestamp = timestamp_; + ++length; + ltrb = traj->ltrb; + xyah = traj->xyah; + TKalmanFilter::correct(cv::Mat(traj->xyah)); + state = Tracked; + is_activated = true; + score = traj->score; + if (update_embedding_) update_embedding(traj->current_embedding); +} + +void Trajectory::activate(int &cnt,int timestamp_) { + id = next_id(cnt); + TKalmanFilter::init(cv::Mat(xyah)); + length = 0; + state = Tracked; + if (timestamp_ == 1) { + is_activated = true; + } + timestamp = timestamp_; + starttime = timestamp_; +} + +void Trajectory::reactivate(Trajectory *traj,int &cnt, int timestamp_, bool newid) { + TKalmanFilter::correct(cv::Mat(traj->xyah)); + update_embedding(traj->current_embedding); + length 
= 0; + state = Tracked; + is_activated = true; + timestamp = timestamp_; + if (newid) id = next_id(cnt); +} + +void Trajectory::update_embedding(const cv::Mat &embedding) { + current_embedding = embedding / cv::norm(embedding); + if (smooth_embedding.empty()) { + smooth_embedding = current_embedding; + } else { + smooth_embedding = eta * smooth_embedding + (1 - eta) * current_embedding; + } + smooth_embedding = smooth_embedding / cv::norm(smooth_embedding); +} + +TrajectoryPool operator+(const TrajectoryPool &a, const TrajectoryPool &b) { + TrajectoryPool sum; + sum.insert(sum.end(), a.begin(), a.end()); + + std::vector ids(a.size()); + for (size_t i = 0; i < a.size(); ++i) ids[i] = a[i].id; + + for (size_t i = 0; i < b.size(); ++i) { + std::vector::iterator iter = find(ids.begin(), ids.end(), b[i].id); + if (iter == ids.end()) { + sum.push_back(b[i]); + ids.push_back(b[i].id); + } + } + + return sum; +} + +TrajectoryPool operator+(const TrajectoryPool &a, const TrajectoryPtrPool &b) { + TrajectoryPool sum; + sum.insert(sum.end(), a.begin(), a.end()); + + std::vector ids(a.size()); + for (size_t i = 0; i < a.size(); ++i) ids[i] = a[i].id; + + for (size_t i = 0; i < b.size(); ++i) { + std::vector::iterator iter = find(ids.begin(), ids.end(), b[i]->id); + if (iter == ids.end()) { + sum.push_back(*b[i]); + ids.push_back(b[i]->id); + } + } + + return sum; +} + +TrajectoryPool &operator+=(TrajectoryPool &a, // NOLINT + const TrajectoryPtrPool &b) { + std::vector ids(a.size()); + for (size_t i = 0; i < a.size(); ++i) ids[i] = a[i].id; + + for (size_t i = 0; i < b.size(); ++i) { + if (b[i]->smooth_embedding.empty()) continue; + std::vector::iterator iter = find(ids.begin(), ids.end(), b[i]->id); + if (iter == ids.end()) { + a.push_back(*b[i]); + ids.push_back(b[i]->id); + } + } + + return a; +} + +TrajectoryPool operator-(const TrajectoryPool &a, const TrajectoryPool &b) { + TrajectoryPool dif; + std::vector ids(b.size()); + for (size_t i = 0; i < b.size(); ++i) ids[i] = 
b[i].id; + + for (size_t i = 0; i < a.size(); ++i) { + std::vector::iterator iter = find(ids.begin(), ids.end(), a[i].id); + if (iter == ids.end()) dif.push_back(a[i]); + } + + return dif; +} + +TrajectoryPool &operator-=(TrajectoryPool &a, // NOLINT + const TrajectoryPool &b) { + std::vector ids(b.size()); + for (size_t i = 0; i < b.size(); ++i) ids[i] = b[i].id; + + TrajectoryPoolIterator piter; + for (piter = a.begin(); piter != a.end();) { + std::vector::iterator iter = find(ids.begin(), ids.end(), piter->id); + if (iter == ids.end()) + ++piter; + else + piter = a.erase(piter); + } + + return a; +} + +TrajectoryPtrPool operator+(const TrajectoryPtrPool &a, + const TrajectoryPtrPool &b) { + TrajectoryPtrPool sum; + sum.insert(sum.end(), a.begin(), a.end()); + + std::vector ids(a.size()); + for (size_t i = 0; i < a.size(); ++i) ids[i] = a[i]->id; + + for (size_t i = 0; i < b.size(); ++i) { + std::vector::iterator iter = find(ids.begin(), ids.end(), b[i]->id); + if (iter == ids.end()) { + sum.push_back(b[i]); + ids.push_back(b[i]->id); + } + } + + return sum; +} + +TrajectoryPtrPool operator+(const TrajectoryPtrPool &a, TrajectoryPool *b) { + TrajectoryPtrPool sum; + sum.insert(sum.end(), a.begin(), a.end()); + + std::vector ids(a.size()); + for (size_t i = 0; i < a.size(); ++i) ids[i] = a[i]->id; + + for (size_t i = 0; i < b->size(); ++i) { + std::vector::iterator iter = find(ids.begin(), ids.end(), (*b)[i].id); + if (iter == ids.end()) { + sum.push_back(&(*b)[i]); + ids.push_back((*b)[i].id); + } + } + + return sum; +} + +TrajectoryPtrPool operator-(const TrajectoryPtrPool &a, + const TrajectoryPtrPool &b) { + TrajectoryPtrPool dif; + std::vector ids(b.size()); + for (size_t i = 0; i < b.size(); ++i) ids[i] = b[i]->id; + + for (size_t i = 0; i < a.size(); ++i) { + std::vector::iterator iter = find(ids.begin(), ids.end(), a[i]->id); + if (iter == ids.end()) dif.push_back(a[i]); + } + + return dif; +} + +cv::Mat embedding_distance(const TrajectoryPool &a, const 
TrajectoryPool &b) { + cv::Mat dists(a.size(), b.size(), CV_32F); + for (size_t i = 0; i < a.size(); ++i) { + float *distsi = dists.ptr(i); + for (size_t j = 0; j < b.size(); ++j) { + cv::Mat u = a[i].smooth_embedding; + cv::Mat v = b[j].smooth_embedding; + double uv = u.dot(v); + double uu = u.dot(u); + double vv = v.dot(v); + double dist = std::abs(1. - uv / std::sqrt(uu * vv)); + // double dist = cv::norm(a[i].smooth_embedding, b[j].smooth_embedding, + // cv::NORM_L2); + distsi[j] = static_cast(std::max(std::min(dist, 2.), 0.)); + } + } + return dists; +} + +cv::Mat embedding_distance(const TrajectoryPtrPool &a, + const TrajectoryPtrPool &b) { + cv::Mat dists(a.size(), b.size(), CV_32F); + for (size_t i = 0; i < a.size(); ++i) { + float *distsi = dists.ptr(i); + for (size_t j = 0; j < b.size(); ++j) { + // double dist = cv::norm(a[i]->smooth_embedding, b[j]->smooth_embedding, + // cv::NORM_L2); + // distsi[j] = static_cast(dist); + cv::Mat u = a[i]->smooth_embedding; + cv::Mat v = b[j]->smooth_embedding; + double uv = u.dot(v); + double uu = u.dot(u); + double vv = v.dot(v); + double dist = std::abs(1. - uv / std::sqrt(uu * vv)); + distsi[j] = static_cast(std::max(std::min(dist, 2.), 0.)); + } + } + + return dists; +} + +cv::Mat embedding_distance(const TrajectoryPtrPool &a, + const TrajectoryPool &b) { + cv::Mat dists(a.size(), b.size(), CV_32F); + for (size_t i = 0; i < a.size(); ++i) { + float *distsi = dists.ptr(i); + for (size_t j = 0; j < b.size(); ++j) { + // double dist = cv::norm(a[i]->smooth_embedding, b[j].smooth_embedding, + // cv::NORM_L2); + // distsi[j] = static_cast(dist); + cv::Mat u = a[i]->smooth_embedding; + cv::Mat v = b[j].smooth_embedding; + double uv = u.dot(v); + double uu = u.dot(u); + double vv = v.dot(v); + double dist = std::abs(1. 
- uv / std::sqrt(uu * vv)); + distsi[j] = static_cast(std::max(std::min(dist, 2.), 0.)); + } + } + + return dists; +} + +cv::Mat mahalanobis_distance(const TrajectoryPool &a, const TrajectoryPool &b) { + std::vector means(a.size()); + std::vector icovariances(a.size()); + for (size_t i = 0; i < a.size(); ++i) { + cv::Mat covariance; + a[i].project(&means[i], &covariance); + cv::invert(covariance, icovariances[i]); + } + + cv::Mat dists(a.size(), b.size(), CV_32F); + for (size_t i = 0; i < a.size(); ++i) { + float *distsi = dists.ptr(i); + for (size_t j = 0; j < b.size(); ++j) { + const cv::Mat x(b[j].xyah); + float dist = + static_cast(cv::Mahalanobis(x, means[i], icovariances[i])); + distsi[j] = dist * dist; + } + } + + return dists; +} + +cv::Mat mahalanobis_distance(const TrajectoryPtrPool &a, + const TrajectoryPtrPool &b) { + std::vector means(a.size()); + std::vector icovariances(a.size()); + for (size_t i = 0; i < a.size(); ++i) { + cv::Mat covariance; + a[i]->project(&means[i], &covariance); + cv::invert(covariance, icovariances[i]); + } + + cv::Mat dists(a.size(), b.size(), CV_32F); + for (size_t i = 0; i < a.size(); ++i) { + float *distsi = dists.ptr(i); + for (size_t j = 0; j < b.size(); ++j) { + const cv::Mat x(b[j]->xyah); + float dist = + static_cast(cv::Mahalanobis(x, means[i], icovariances[i])); + distsi[j] = dist * dist; + } + } + + return dists; +} + +cv::Mat mahalanobis_distance(const TrajectoryPtrPool &a, + const TrajectoryPool &b) { + std::vector means(a.size()); + std::vector icovariances(a.size()); + + for (size_t i = 0; i < a.size(); ++i) { + cv::Mat covariance; + a[i]->project(&means[i], &covariance); + cv::invert(covariance, icovariances[i]); + } + + cv::Mat dists(a.size(), b.size(), CV_32F); + for (size_t i = 0; i < a.size(); ++i) { + float *distsi = dists.ptr(i); + for (size_t j = 0; j < b.size(); ++j) { + const cv::Mat x(b[j].xyah); + float dist = + static_cast(cv::Mahalanobis(x, means[i], icovariances[i])); + distsi[j] = dist * dist; + 
} + } + + return dists; +} + +static inline float calc_inter_area(const cv::Vec4f &a, const cv::Vec4f &b) { + if (a[2] < b[0] || a[0] > b[2] || a[3] < b[1] || a[1] > b[3]) return 0.f; + + float w = std::min(a[2], b[2]) - std::max(a[0], b[0]); + float h = std::min(a[3], b[3]) - std::max(a[1], b[1]); + return w * h; +} + +cv::Mat iou_distance(const TrajectoryPool &a, const TrajectoryPool &b) { + std::vector areaa(a.size()); + for (size_t i = 0; i < a.size(); ++i) { + float w = a[i].ltrb[2] - a[i].ltrb[0]; + float h = a[i].ltrb[3] - a[i].ltrb[1]; + areaa[i] = w * h; + } + + std::vector areab(b.size()); + for (size_t j = 0; j < b.size(); ++j) { + float w = b[j].ltrb[2] - b[j].ltrb[0]; + float h = b[j].ltrb[3] - b[j].ltrb[1]; + areab[j] = w * h; + } + + cv::Mat dists(a.size(), b.size(), CV_32F); + for (size_t i = 0; i < a.size(); ++i) { + const cv::Vec4f &boxa = a[i].ltrb; + float *distsi = dists.ptr(i); + for (size_t j = 0; j < b.size(); ++j) { + const cv::Vec4f &boxb = b[j].ltrb; + float inters = calc_inter_area(boxa, boxb); + distsi[j] = 1.f - inters / (areaa[i] + areab[j] - inters); + } + } + + return dists; +} + +cv::Mat iou_distance(const TrajectoryPtrPool &a, const TrajectoryPtrPool &b) { + std::vector areaa(a.size()); + for (size_t i = 0; i < a.size(); ++i) { + float w = a[i]->ltrb[2] - a[i]->ltrb[0]; + float h = a[i]->ltrb[3] - a[i]->ltrb[1]; + areaa[i] = w * h; + } + + std::vector areab(b.size()); + for (size_t j = 0; j < b.size(); ++j) { + float w = b[j]->ltrb[2] - b[j]->ltrb[0]; + float h = b[j]->ltrb[3] - b[j]->ltrb[1]; + areab[j] = w * h; + } + + cv::Mat dists(a.size(), b.size(), CV_32F); + for (size_t i = 0; i < a.size(); ++i) { + const cv::Vec4f &boxa = a[i]->ltrb; + float *distsi = dists.ptr(i); + for (size_t j = 0; j < b.size(); ++j) { + const cv::Vec4f &boxb = b[j]->ltrb; + float inters = calc_inter_area(boxa, boxb); + distsi[j] = 1.f - inters / (areaa[i] + areab[j] - inters); + } + } + + return dists; +} + +cv::Mat iou_distance(const 
TrajectoryPtrPool &a, const TrajectoryPool &b) { + std::vector areaa(a.size()); + for (size_t i = 0; i < a.size(); ++i) { + float w = a[i]->ltrb[2] - a[i]->ltrb[0]; + float h = a[i]->ltrb[3] - a[i]->ltrb[1]; + areaa[i] = w * h; + } + + std::vector areab(b.size()); + for (size_t j = 0; j < b.size(); ++j) { + float w = b[j].ltrb[2] - b[j].ltrb[0]; + float h = b[j].ltrb[3] - b[j].ltrb[1]; + areab[j] = w * h; + } + + cv::Mat dists(a.size(), b.size(), CV_32F); + for (size_t i = 0; i < a.size(); ++i) { + const cv::Vec4f &boxa = a[i]->ltrb; + float *distsi = dists.ptr(i); + for (size_t j = 0; j < b.size(); ++j) { + const cv::Vec4f &boxb = b[j].ltrb; + float inters = calc_inter_area(boxa, boxb); + distsi[j] = 1.f - inters / (areaa[i] + areab[j] - inters); + } + } + + return dists; +} + +} // namespace tracking +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/tracking/pptracking/trajectory.h b/fastdeploy/vision/tracking/pptracking/trajectory.h new file mode 100644 index 000000000..a869f8409 --- /dev/null +++ b/fastdeploy/vision/tracking/pptracking/trajectory.h @@ -0,0 +1,234 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// The code is based on: +// https://github.com/CnybTseng/JDE/blob/master/platforms/common/trajectory.h +// Ths copyright of CnybTseng/JDE is as follows: +// MIT License + +#pragma once + +#include +#include "fastdeploy/fastdeploy_model.h" +#include +#include +#include +#include "opencv2/video/tracking.hpp" + +namespace fastdeploy { +namespace vision { +namespace tracking { + +typedef enum { New = 0, Tracked = 1, Lost = 2, Removed = 3 } TrajectoryState; + +class Trajectory; +typedef std::vector TrajectoryPool; +typedef std::vector::iterator TrajectoryPoolIterator; +typedef std::vector TrajectoryPtrPool; +typedef std::vector::iterator TrajectoryPtrPoolIterator; + +class FASTDEPLOY_DECL TKalmanFilter : public cv::KalmanFilter { + public: + TKalmanFilter(void); + virtual ~TKalmanFilter(void) {} + virtual void init(const cv::Mat &measurement); + virtual const cv::Mat &predict(); + virtual const cv::Mat &correct(const cv::Mat &measurement); + virtual void project(cv::Mat *mean, cv::Mat *covariance) const; + + private: + float std_weight_position; + float std_weight_velocity; +}; + +inline TKalmanFilter::TKalmanFilter(void) : cv::KalmanFilter(8, 4) { + cv::KalmanFilter::transitionMatrix = cv::Mat::eye(8, 8, CV_32F); + for (int i = 0; i < 4; ++i) + cv::KalmanFilter::transitionMatrix.at(i, i + 4) = 1; + cv::KalmanFilter::measurementMatrix = cv::Mat::eye(4, 8, CV_32F); + std_weight_position = 1 / 20.f; + std_weight_velocity = 1 / 160.f; +} + +class FASTDEPLOY_DECL Trajectory : public TKalmanFilter { + public: + Trajectory(); + Trajectory(const cv::Vec4f <rb, float score, const cv::Mat &embedding); + Trajectory(const Trajectory &other); + Trajectory &operator=(const Trajectory &rhs); + virtual ~Trajectory(void) {} + + int next_id(int &nt); + virtual const cv::Mat &predict(void); + virtual void update(Trajectory *traj, + int timestamp, + bool update_embedding = true); + virtual void activate(int& cnt, int timestamp); + virtual void reactivate(Trajectory *traj, int & 
cnt,int timestamp, bool newid = false); + virtual void mark_lost(void); + virtual void mark_removed(void); + + friend TrajectoryPool operator+(const TrajectoryPool &a, + const TrajectoryPool &b); + friend TrajectoryPool operator+(const TrajectoryPool &a, + const TrajectoryPtrPool &b); + friend TrajectoryPool &operator+=(TrajectoryPool &a, // NOLINT + const TrajectoryPtrPool &b); + friend TrajectoryPool operator-(const TrajectoryPool &a, + const TrajectoryPool &b); + friend TrajectoryPool &operator-=(TrajectoryPool &a, // NOLINT + const TrajectoryPool &b); + friend TrajectoryPtrPool operator+(const TrajectoryPtrPool &a, + const TrajectoryPtrPool &b); + friend TrajectoryPtrPool operator+(const TrajectoryPtrPool &a, + TrajectoryPool *b); + friend TrajectoryPtrPool operator-(const TrajectoryPtrPool &a, + const TrajectoryPtrPool &b); + + friend cv::Mat embedding_distance(const TrajectoryPool &a, + const TrajectoryPool &b); + friend cv::Mat embedding_distance(const TrajectoryPtrPool &a, + const TrajectoryPtrPool &b); + friend cv::Mat embedding_distance(const TrajectoryPtrPool &a, + const TrajectoryPool &b); + + friend cv::Mat mahalanobis_distance(const TrajectoryPool &a, + const TrajectoryPool &b); + friend cv::Mat mahalanobis_distance(const TrajectoryPtrPool &a, + const TrajectoryPtrPool &b); + friend cv::Mat mahalanobis_distance(const TrajectoryPtrPool &a, + const TrajectoryPool &b); + + friend cv::Mat iou_distance(const TrajectoryPool &a, const TrajectoryPool &b); + friend cv::Mat iou_distance(const TrajectoryPtrPool &a, + const TrajectoryPtrPool &b); + friend cv::Mat iou_distance(const TrajectoryPtrPool &a, + const TrajectoryPool &b); + + private: + void update_embedding(const cv::Mat &embedding); + + public: + TrajectoryState state; + cv::Vec4f ltrb; + cv::Mat smooth_embedding; + int id; + bool is_activated; + int timestamp; + int starttime; + float score; + + private: +// int count=0; + cv::Vec4f xyah; + cv::Mat current_embedding; + float eta; + int length; +}; + 
+inline cv::Vec4f ltrb2xyah(const cv::Vec4f <rb) { + cv::Vec4f xyah; + xyah[0] = (ltrb[0] + ltrb[2]) * 0.5f; + xyah[1] = (ltrb[1] + ltrb[3]) * 0.5f; + xyah[3] = ltrb[3] - ltrb[1]; + xyah[2] = (ltrb[2] - ltrb[0]) / xyah[3]; + return xyah; +} + +inline Trajectory::Trajectory() + : state(New), + ltrb(cv::Vec4f()), + smooth_embedding(cv::Mat()), + id(0), + is_activated(false), + timestamp(0), + starttime(0), + score(0), + eta(0.9), + length(0) {} + +inline Trajectory::Trajectory(const cv::Vec4f <rb_, + float score_, + const cv::Mat &embedding) + : state(New), + ltrb(ltrb_), + smooth_embedding(cv::Mat()), + id(0), + is_activated(false), + timestamp(0), + starttime(0), + score(score_), + eta(0.9), + length(0) { + xyah = ltrb2xyah(ltrb); + update_embedding(embedding); +} + +inline Trajectory::Trajectory(const Trajectory &other) + : state(other.state), + ltrb(other.ltrb), + id(other.id), + is_activated(other.is_activated), + timestamp(other.timestamp), + starttime(other.starttime), + xyah(other.xyah), + score(other.score), + eta(other.eta), + length(other.length) { + other.smooth_embedding.copyTo(smooth_embedding); + other.current_embedding.copyTo(current_embedding); + // copy state in KalmanFilter + + other.statePre.copyTo(cv::KalmanFilter::statePre); + other.statePost.copyTo(cv::KalmanFilter::statePost); + other.errorCovPre.copyTo(cv::KalmanFilter::errorCovPre); + other.errorCovPost.copyTo(cv::KalmanFilter::errorCovPost); +} + +inline Trajectory &Trajectory::operator=(const Trajectory &rhs) { + this->state = rhs.state; + this->ltrb = rhs.ltrb; + rhs.smooth_embedding.copyTo(this->smooth_embedding); + this->id = rhs.id; + this->is_activated = rhs.is_activated; + this->timestamp = rhs.timestamp; + this->starttime = rhs.starttime; + this->xyah = rhs.xyah; + this->score = rhs.score; + rhs.current_embedding.copyTo(this->current_embedding); + this->eta = rhs.eta; + this->length = rhs.length; + + // copy state in KalmanFilter + + rhs.statePre.copyTo(cv::KalmanFilter::statePre); 
+ rhs.statePost.copyTo(cv::KalmanFilter::statePost); + rhs.errorCovPre.copyTo(cv::KalmanFilter::errorCovPre); + rhs.errorCovPost.copyTo(cv::KalmanFilter::errorCovPost); + + return *this; +} + +inline int Trajectory::next_id(int &cnt) { + ++cnt; + return cnt; +} + +inline void Trajectory::mark_lost(void) { state = Lost; } + +inline void Trajectory::mark_removed(void) { state = Removed; } + +} // namespace tracking +} // namespace vision +} // namespace fastdeploy diff --git a/fastdeploy/vision/tracking/tracking_pybind.cc b/fastdeploy/vision/tracking/tracking_pybind.cc new file mode 100644 index 000000000..7341c85f2 --- /dev/null +++ b/fastdeploy/vision/tracking/tracking_pybind.cc @@ -0,0 +1,26 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "fastdeploy/pybind/main.h" + +namespace fastdeploy { + + void BindPPTracking(pybind11::module& m); + + void BindTracking(pybind11::module& m) { + auto tracking_module = + m.def_submodule("tracking", "object tracking models."); + BindPPTracking(tracking_module); + } +} // namespace fastdeploy diff --git a/fastdeploy/vision/vision_pybind.cc b/fastdeploy/vision/vision_pybind.cc index 70f1990a4..89b08f119 100644 --- a/fastdeploy/vision/vision_pybind.cc +++ b/fastdeploy/vision/vision_pybind.cc @@ -23,6 +23,7 @@ void BindMatting(pybind11::module& m); void BindFaceDet(pybind11::module& m); void BindFaceId(pybind11::module& m); void BindOcr(pybind11::module& m); +void BindTracking(pybind11::module& m); void BindKeyPointDetection(pybind11::module& m); #ifdef ENABLE_VISION_VISUALIZE void BindVisualize(pybind11::module& m); @@ -63,6 +64,15 @@ void BindVision(pybind11::module& m) { .def("__repr__", &vision::OCRResult::Str) .def("__str__", &vision::OCRResult::Str); + pybind11::class_(m, "MOTResult") + .def(pybind11::init()) + .def_readwrite("boxes", &vision::MOTResult::boxes) + .def_readwrite("ids", &vision::MOTResult::ids) + .def_readwrite("scores", &vision::MOTResult::scores) + .def_readwrite("class_ids", &vision::MOTResult::class_ids) + .def("__repr__", &vision::MOTResult::Str) + .def("__str__", &vision::MOTResult::Str); + pybind11::class_(m, "FaceDetectionResult") .def(pybind11::init()) .def_readwrite("boxes", &vision::FaceDetectionResult::boxes) @@ -112,6 +122,7 @@ void BindVision(pybind11::module& m) { BindFaceId(m); BindMatting(m); BindOcr(m); + BindTracking(m); BindKeyPointDetection(m); #ifdef ENABLE_VISION_VISUALIZE BindVisualize(m); diff --git a/fastdeploy/vision/visualize/mot.cc b/fastdeploy/vision/visualize/mot.cc new file mode 100644 index 000000000..9877b2d4e --- /dev/null +++ b/fastdeploy/vision/visualize/mot.cc @@ -0,0 +1,100 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef ENABLE_VISION_VISUALIZE +#include "fastdeploy/vision/visualize/visualize.h" +#include + +namespace fastdeploy { +namespace vision { + +cv::Scalar GetMOTBoxColor(int idx) { + idx = idx * 3; + cv::Scalar color = cv::Scalar((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255); + return color; +} + + +cv::Mat VisMOT(const cv::Mat &img, const MOTResult &results, float fps, int frame_id) { + + cv::Mat vis_img = img.clone(); + int im_h = img.rows; + int im_w = img.cols; + float text_scale = std::max(1, static_cast(im_w / 1600.)); + float text_thickness = 2.; + float line_thickness = std::max(1, static_cast(im_w / 500.)); + + std::ostringstream oss; + oss << std::setiosflags(std::ios::fixed) << std::setprecision(4); + oss << "frame: " << frame_id << " "; + oss << "fps: " << fps << " "; + oss << "num: " << results.boxes.size(); + std::string text = oss.str(); + + cv::Point origin; + origin.x = 0; + origin.y = static_cast(15 * text_scale); + cv::putText(vis_img, + text, + origin, + cv::FONT_HERSHEY_PLAIN, + text_scale, + cv::Scalar(0, 0, 255), + text_thickness); + + for (int i = 0; i < results.boxes.size(); ++i) { + const int obj_id = results.ids[i]; + const float score = results.scores[i]; + + cv::Scalar color = GetMOTBoxColor(obj_id); + + cv::Point pt1 = cv::Point(results.boxes[i][0], results.boxes[i][1]); + cv::Point pt2 = cv::Point(results.boxes[i][2], results.boxes[i][3]); + cv::Point id_pt = + 
cv::Point(results.boxes[i][0], results.boxes[i][1] + 10); + cv::Point score_pt = + cv::Point(results.boxes[i][0], results.boxes[i][1] - 10); + cv::rectangle(vis_img, pt1, pt2, color, line_thickness); + + std::ostringstream idoss; + idoss << std::setiosflags(std::ios::fixed) << std::setprecision(4); + idoss << obj_id; + std::string id_text = idoss.str(); + + cv::putText(vis_img, + id_text, + id_pt, + cv::FONT_HERSHEY_PLAIN, + text_scale, + cv::Scalar(0, 255, 255), + text_thickness); + + std::ostringstream soss; + soss << std::setiosflags(std::ios::fixed) << std::setprecision(2); + soss << score; + std::string score_text = soss.str(); + + cv::putText(vis_img, + score_text, + score_pt, + cv::FONT_HERSHEY_PLAIN, + text_scale, + cv::Scalar(0, 255, 255), + text_thickness); + } + return vis_img; +} +}// namespace vision +} //namespace fastdepoly +#endif diff --git a/fastdeploy/vision/visualize/ocr.cc b/fastdeploy/vision/visualize/ocr.cc index 5a47b4909..ac8f36312 100644 --- a/fastdeploy/vision/visualize/ocr.cc +++ b/fastdeploy/vision/visualize/ocr.cc @@ -15,7 +15,6 @@ #ifdef ENABLE_VISION_VISUALIZE #include "fastdeploy/vision/visualize/visualize.h" -#include "opencv2/imgproc/imgproc.hpp" namespace fastdeploy { namespace vision { diff --git a/fastdeploy/vision/visualize/visualize.h b/fastdeploy/vision/visualize/visualize.h index 256ffba70..6882156c1 100644 --- a/fastdeploy/vision/visualize/visualize.h +++ b/fastdeploy/vision/visualize/visualize.h @@ -77,6 +77,9 @@ FASTDEPLOY_DECL cv::Mat VisMatting(const cv::Mat& im, const MattingResult& result, bool remove_small_connected_area = false); FASTDEPLOY_DECL cv::Mat VisOcr(const cv::Mat& im, const OCRResult& ocr_result); + +FASTDEPLOY_DECL cv::Mat VisMOT(const cv::Mat& img,const MOTResult& results, float fps=0.0, int frame_id=0); + FASTDEPLOY_DECL cv::Mat SwapBackground( const cv::Mat& im, const cv::Mat& background, const MattingResult& result, bool remove_small_connected_area = false); diff --git 
a/fastdeploy/vision/visualize/visualize_pybind.cc b/fastdeploy/vision/visualize/visualize_pybind.cc index 83ccb34f4..fcb7bd39f 100644 --- a/fastdeploy/vision/visualize/visualize_pybind.cc +++ b/fastdeploy/vision/visualize/visualize_pybind.cc @@ -75,6 +75,14 @@ void BindVisualize(pybind11::module& m) { vision::Mat(vis_im).ShareWithTensor(&out); return TensorToPyArray(out); }) + .def("vis_mot", + [](pybind11::array& im_data, vision::MOTResult& result,float fps, int frame_id) { + auto im = PyArrayToCvMat(im_data); + auto vis_im = vision::VisMOT(im, result,fps,frame_id); + FDTensor out; + vision::Mat(vis_im).ShareWithTensor(&out); + return TensorToPyArray(out); + }) .def("vis_matting", [](pybind11::array& im_data, vision::MattingResult& result, bool remove_small_connected_area) { @@ -166,6 +174,14 @@ void BindVisualize(pybind11::module& m) { vision::Mat(vis_im).ShareWithTensor(&out); return TensorToPyArray(out); }) + .def_static("vis_mot", + [](pybind11::array& im_data, vision::MOTResult& result,float fps, int frame_id) { + auto im = PyArrayToCvMat(im_data); + auto vis_im = vision::VisMOT(im, result,fps,frame_id); + FDTensor out; + vision::Mat(vis_im).ShareWithTensor(&out); + return TensorToPyArray(out); + }) .def_static("vis_matting_alpha", [](pybind11::array& im_data, vision::MattingResult& result, bool remove_small_connected_area) { diff --git a/python/fastdeploy/libs/__init__.py b/python/fastdeploy/libs/__init__.py index 8b1378917..e69de29bb 100644 --- a/python/fastdeploy/libs/__init__.py +++ b/python/fastdeploy/libs/__init__.py @@ -1 +0,0 @@ - diff --git a/python/fastdeploy/vision/__init__.py b/python/fastdeploy/vision/__init__.py index 28cd0564f..ce282daaf 100644 --- a/python/fastdeploy/vision/__init__.py +++ b/python/fastdeploy/vision/__init__.py @@ -16,8 +16,8 @@ from __future__ import absolute_import from . import detection from . import classification from . import segmentation +from . import tracking from . import keypointdetection - from . 
import matting from . import facedet from . import faceid diff --git a/python/fastdeploy/vision/tracking/__init__.py b/python/fastdeploy/vision/tracking/__init__.py new file mode 100644 index 000000000..946dfd971 --- /dev/null +++ b/python/fastdeploy/vision/tracking/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import + +from .pptracking import PPTracking diff --git a/python/fastdeploy/vision/tracking/pptracking/__init__.py b/python/fastdeploy/vision/tracking/pptracking/__init__.py new file mode 100644 index 000000000..9da0d669f --- /dev/null +++ b/python/fastdeploy/vision/tracking/pptracking/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from .... import FastDeployModel, ModelFormat +from .... 
import c_lib_wrap as C + + +class PPTracking(FastDeployModel): + def __init__(self, + model_file, + params_file, + config_file, + runtime_option=None, + model_format=ModelFormat.PADDLE): + super(PPTracking, self).__init__(runtime_option) + + assert model_format == ModelFormat.PADDLE, "PPTracking model only support model format of ModelFormat.Paddle now." + self._model = C.vision.tracking.PPTracking( + model_file, params_file, config_file, self._runtime_option, + model_format) + assert self.initialized, "PPTracking model initialize failed." + + def predict(self, input_image): + assert input_image is not None, "The input image data is None." + return self._model.predict(input_image) diff --git a/python/fastdeploy/vision/visualize/__init__.py b/python/fastdeploy/vision/visualize/__init__.py index 95b8e816a..37ac727dc 100644 --- a/python/fastdeploy/vision/visualize/__init__.py +++ b/python/fastdeploy/vision/visualize/__init__.py @@ -98,3 +98,7 @@ def swap_background(im_data, def vis_ppocr(im_data, det_result): return C.vision.vis_ppocr(im_data, det_result) + + +def vis_mot(im_data, mot_result, fps, frame_id): + return C.vision.vis_mot(im_data, mot_result, fps, frame_id) diff --git a/python/scripts/process_libraries.py.in b/python/scripts/process_libraries.py.in index 31fb56ed0..25b7a5e4b 100644 --- a/python/scripts/process_libraries.py.in +++ b/python/scripts/process_libraries.py.in @@ -46,7 +46,7 @@ def process_on_linux(current_dir): if len(items) != 4: os.remove(os.path.join(root, f)) continue - if items[0].strip() not in ["libopencv_highgui", "libopencv_videoio", "libopencv_imgcodecs", "libopencv_imgproc", "libopencv_core"]: + if items[0].strip() not in ["libopencv_highgui", "libopencv_video", "libopencv_videoio", "libopencv_imgcodecs", "libopencv_imgproc", "libopencv_core"]: os.remove(os.path.join(root, f)) all_libs_paths = [third_libs_path] + user_specified_dirs diff --git a/tests/eval_example/test_pptracking.py b/tests/eval_example/test_pptracking.py new file 
mode 100644 index 000000000..ee1cb9bc5 --- /dev/null +++ b/tests/eval_example/test_pptracking.py @@ -0,0 +1,89 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import fastdeploy as fd +import cv2 +import os +import numpy as np +import pickle + + +def test_pptracking_cpu(): + model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/pptracking.tgz" + input_url = "https://bj.bcebos.com/paddlehub/fastdeploy/person.mp4" + fd.download_and_decompress(model_url, ".") + fd.download(input_url, ".") + model_path = "pptracking/fairmot_hrnetv2_w18_dlafpn_30e_576x320" + # use default backend + runtime_option = fd.RuntimeOption() + model_file = os.path.join(model_path, "model.pdmodel") + params_file = os.path.join(model_path, "model.pdiparams") + config_file = os.path.join(model_path, "infer_cfg.yml") + model = fd.vision.tracking.PPTracking(model_file, params_file, config_file, runtime_option=runtime_option) + cap = cv2.VideoCapture("./person.mp4") + frame_id = 0 + while True: + _, frame = cap.read() + if frame is None: + break + result = model.predict(frame) + # compare diff + expect = pickle.load(open("pptracking/frame" + str(frame_id) + ".pkl", "rb")) + diff_boxes = np.fabs(np.array(expect["boxes"]) - np.array(result.boxes)) + diff_scores = np.fabs(np.array(expect["scores"]) - np.array(result.scores)) + diff = max(diff_boxes.max(), diff_scores.max()) + thres = 1e-05 + assert diff < thres, "The label diff is %f, 
which is bigger than %f" % (diff, thres) + frame_id = frame_id + 1 + cv2.waitKey(30) + if frame_id >= 10: + cap.release() + cv2.destroyAllWindows() + break + + +def test_pptracking_gpu(): + model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/pptracking.tgz" + input_url = "https://bj.bcebos.com/paddlehub/fastdeploy/person.mp4" + fd.download_and_decompress(model_url, ".") + fd.download(input_url, ".") + model_path = "pptracking/fairmot_hrnetv2_w18_dlafpn_30e_576x320" + runtime_option = fd.RuntimeOption() + runtime_option.use_gpu() + # Not supported trt backend, up to now + # runtime_option.use_trt_backend() + model_file = os.path.join(model_path, "model.pdmodel") + params_file = os.path.join(model_path, "model.pdiparams") + config_file = os.path.join(model_path, "infer_cfg.yml") + model = fd.vision.tracking.PPTracking(model_file, params_file, config_file, runtime_option=runtime_option) + cap = cv2.VideoCapture("./person.mp4") + frame_id = 0 + while True: + _, frame = cap.read() + if frame is None: + break + result = model.predict(frame) + # compare diff + expect = pickle.load(open("pptracking/frame" + str(frame_id) + ".pkl", "rb")) + diff_boxes = np.fabs(np.array(expect["boxes"]) - np.array(result.boxes)) + diff_scores = np.fabs(np.array(expect["scores"]) - np.array(result.scores)) + diff = max(diff_boxes.max(), diff_scores.max()) + thres = 1e-05 + assert diff < thres, "The label diff is %f, which is bigger than %f" % (diff, thres) + frame_id = frame_id + 1 + cv2.waitKey(30) + if frame_id >= 10: + cap.release() + cv2.destroyAllWindows() + break