[Model] add pptracking model (#357)

* add override mark

* delete some

* recovery

* recovery

* add tracking

* add tracking py_bind and example

* add pptracking

* add pptracking

* iomanip head file

* add opencv_video lib

* add python libs package

Signed-off-by: ChaoII <849453582@qq.com>

* complete comments

Signed-off-by: ChaoII <849453582@qq.com>

* add jdeTracker_ member variable

Signed-off-by: ChaoII <849453582@qq.com>

* add 'FASTDEPLOY_DECL' macro

Signed-off-by: ChaoII <849453582@qq.com>

* remove kwargs params

Signed-off-by: ChaoII <849453582@qq.com>

* [Doc]update pptracking docs

* delete 'ENABLE_PADDLE_FRONTEND' switch

* add pptracking unit test

* update pptracking unit test

Signed-off-by: ChaoII <849453582@qq.com>

* modify test video file path and remove trt test

* update unit test model url

* remove 'FASTDEPLOY_DECL' macro

Signed-off-by: ChaoII <849453582@qq.com>

* fix build python packages about pptracking on win32

Signed-off-by: ChaoII <849453582@qq.com>

Signed-off-by: ChaoII <849453582@qq.com>
Co-authored-by: Jason <jiangjiajun@baidu.com>
This commit is contained in:
ChaoII
2022-10-26 14:27:55 +08:00
committed by GitHub
parent da7247aa41
commit ba501fd963
38 changed files with 2959 additions and 16 deletions

2
.gitignore vendored
View File

@@ -13,12 +13,14 @@ fastdeploy/version.py
fastdeploy/core/config.h
python/fastdeploy/c_lib_wrap.py
python/fastdeploy/LICENSE*
python/build_cpu.sh
python/fastdeploy/ThirdPartyNotices*
*.so*
fpython/astdeploy/libs/third_libs
fastdeploy/core/config.h
fastdeploy/pybind/main.cc
python/fastdeploy/libs/lib*
python/fastdeploy/libs/third_libs
__pycache__
build_fd_android.sh
python/scripts/process_libraries.py

0
benchmark/run_benchmark_yolo.sh Normal file → Executable file
View File

View File

@@ -186,6 +186,6 @@ else()
endif()
find_package(OpenCV REQUIRED PATHS ${OpenCV_DIR})
include_directories(${OpenCV_INCLUDE_DIRS})
list(APPEND DEPEND_LIBS opencv_core opencv_highgui opencv_imgproc opencv_imgcodecs)
list(APPEND DEPEND_LIBS opencv_core opencv_video opencv_highgui opencv_imgproc opencv_imgcodecs)
endif()
endif()

View File

@@ -3,7 +3,7 @@
FastDeploy根据视觉模型的任务类型定义了不同的结构体(`fastdeploy/vision/common/result.h`)来表达模型预测结果,具体如下表所示
| 结构体 | 文档 | 说明 | 相应模型 |
| :----- | :--- | :---- | :------- |
|:------------------------|:----------------------------------------------|:------------------|:------------------------|
| ClassifyResult | [C++/Python文档](./classification_result.md) | 图像分类返回结果 | ResNet50、MobileNetV3等 |
| SegmentationResult | [C++/Python文档](./segmentation_result.md) | 图像分割返回结果 | PP-HumanSeg、PP-LiteSeg等 |
| DetectionResult | [C++/Python文档](./detection_result.md) | 目标检测返回结果 | PP-YOLOE、YOLOv7系列模型等 |
@@ -12,3 +12,4 @@ FastDeploy根据视觉模型的任务类型定义了不同的结构体(`fastd
| FaceRecognitionResult | [C++/Python文档](./face_recognition_result.md) | 目标检测返回结果 | ArcFace、CosFace系列模型等 |
| MattingResult | [C++/Python文档](./matting_result.md) | 目标检测返回结果 | MODNet系列模型等 |
| OCRResult | [C++/Python文档](./ocr_result.md) | 文本框检测,分类和文本识别返回结果 | OCR系列模型等 |
| MOTResult | [C++/Python文档](./mot_result.md) | 多目标跟踪返回结果 | pptracking系列模型等 |

View File

@@ -0,0 +1,40 @@
# MOTResult 多目标跟踪结果
MOTResult代码定义在`fastdeploy/vision/common/result.h`用于表明多目标跟踪中的检测出来的目标框、目标跟踪id、目标类别和目标置信度。
## C++ 定义
```c++
fastdeploy::vision::MOTResult
```
```c++
struct MOTResult{
// left top right bottom
std::vector<std::array<int, 4>> boxes;
std::vector<int> ids;
std::vector<float> scores;
std::vector<int> class_ids;
void Clear();
std::string Str();
};
```
- **boxes**: 成员变量,表示单帧画面中检测出来的所有目标框坐标,`boxes.size()`表示框的个数每个框以4个float数值依次表示xmin, ymin, xmax, ymax 即左上角和右下角坐标
- **ids**: 成员变量表示单帧画面中所有目标的id其元素个数与`boxes.size()`一致
- **scores**: 成员变量,表示单帧画面检测出来的所有目标置信度,其元素个数与`boxes.size()`一致
- **class_ids**: 成员变量,表示单帧画面出来的所有目标类别,其元素个数与`boxes.size()`一致
- **Clear()**: 成员函数,用于清除结构体中存储的结果
- **Str()**: 成员函数将结构体中的信息以字符串形式输出用于Debug
## Python 定义
```python
fastdeploy.vision.MOTResult
```
- **boxes**(list of list(float)): 成员变量表示单帧画面中检测出来的所有目标框坐标。boxes是一个list其每个元素为一个长度为4的list 表示为一个框每个框以4个float数值依次表示xmin, ymin, xmax, ymax 即左上角和右下角坐标
- **ids**(list of list(float)):成员变量表示单帧画面中所有目标的id其元素个数与`boxes`一致
- **scores**(list of float): 成员变量,表示单帧画面检测出来的所有目标置信度
- **class_ids**(list of int): 成员变量,表示单帧画面出来的所有目标类别

View File

@@ -0,0 +1,14 @@
PROJECT(infer_demo C CXX)
CMAKE_MINIMUM_REQUIRED (VERSION 3.10)
# 指定下载解压后的fastdeploy库路径
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
# 添加FastDeploy依赖头文件
include_directories(${FASTDEPLOY_INCS})
add_executable(infer_demo ${PROJECT_SOURCE_DIR}/infer.cc)
# 添加FastDeploy库依赖
target_link_libraries(infer_demo ${FASTDEPLOY_LIBS})

View File

@@ -0,0 +1,79 @@
# PP-Tracking C++部署示例
本目录下提供`infer.cc`快速完成PP-Tracking在CPU/GPU以及GPU上通过TensorRT加速部署的示例。
在部署前,需确认以下两个步骤
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
- 2. 根据开发环境下载预编译部署库和samples代码参考[FastDeploy预编译库](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
以Linux上 PP-Tracking 推理为例在本目录执行如下命令即可完成编译测试如若只需在CPU上部署可在[Fastdeploy C++预编译库](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md/CPP_prebuilt_libraries.md)下载CPU推理库
```bash
#下载SDK编译模型examples代码SDK中包含了examples代码
wget https://bj.bcebos.com/fastdeploy/release/cpp/fastdeploy-linux-x64-gpu-0.3.0.tgz
tar xvf fastdeploy-linux-x64-gpu-0.3.0.tgz
cd fastdeploy-linux-x64-gpu-0.3.0/examples/vision/tracking/pptracking/cpp/
mkdir build && cd build
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/../../../../../../../fastdeploy-linux-x64-gpu-0.3.0
make -j
# 下载PP-Tracking模型文件和测试视频
wget https://bj.bcebos.com/paddlehub/fastdeploy/fairmot_hrnetv2_w18_dlafpn_30e_576x320.tgz
tar -xvf fairmot_hrnetv2_w18_dlafpn_30e_576x320.tgz
wget https://bj.bcebos.com/paddlehub/fastdeploy/person.mp4
wget https://bj.bcebos.com/paddlehub/fastdeploy/person.mp4
# CPU推理
./infer_demo fairmot_hrnetv2_w18_dlafpn_30e_576x320 person.mp4 0
# GPU推理
./infer_demo fairmot_hrnetv2_w18_dlafpn_30e_576x320 person.mp4 1
# GPU上TensorRT推理
./infer_demo fairmot_hrnetv2_w18_dlafpn_30e_576x320 person.mp4 2
```
以上命令只适用于Linux或MacOS, Windows下SDK的使用方式请参考:
- [如何在Windows中使用FastDeploy C++ SDK](../../../../../docs/cn/faq/use_sdk_on_windows.md)
## PP-Tracking C++接口
### PPTracking类
```c++
fastdeploy::vision::tracking::PPTracking(
const string& model_file,
const string& params_file = "",
const string& config_file,
const RuntimeOption& runtime_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE)
```
PP-Tracking模型加载和初始化其中model_file为导出的Paddle模型格式。
**参数**
> * **model_file**(str): 模型文件路径
> * **params_file**(str): 参数文件路径
> * **config_file**(str): 推理部署配置文件
> * **runtime_option**(RuntimeOption): 后端推理配置默认为None即采用默认配置
> * **model_format**(ModelFormat): 模型格式默认为Paddle格式
#### Predict函数
> ```c++
> PPTracking::Predict(cv::Mat* im, MOTResult* result)
> ```
>
> 模型预测接口,输入图像直接输出检测结果。
>
> **参数**
>
> > * **im**: 输入图像注意需为HWCBGR格式
> > * **result**: 检测结果包括检测框跟踪id各个框的置信度对象类别idMOTResult说明参考[视觉模型预测结果](../../../../../docs/api/vision_results/)
- [模型介绍](../../)
- [Python部署](../python)
- [视觉模型预测结果](../../../../../docs/api/vision_results/)
- [如何切换模型推理后端引擎](../../../../../docs/cn/faq/how_to_change_backend.md)

View File

@@ -0,0 +1,158 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision.h"
#ifdef WIN32
const char sep = '\\';
#else
const char sep = '/';
#endif
void CpuInfer(const std::string& model_dir, const std::string& video_file) {
auto model_file = model_dir + sep + "model.pdmodel";
auto params_file = model_dir + sep + "model.pdiparams";
auto config_file = model_dir + sep + "infer_cfg.yml";
auto model = fastdeploy::vision::tracking::PPTracking(
model_file, params_file, config_file);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
fastdeploy::vision::MOTResult result;
cv::Mat frame;
int frame_id=0;
cv::VideoCapture capture(video_file);
// according to the time of prediction to calculate fps
float fps= 0.0f;
while (capture.read(frame)) {
if (frame.empty()) {
break;
}
if (!model.Predict(&frame, &result)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
// std::cout << result.Str() << std::endl;
cv::Mat out_img = fastdeploy::vision::VisMOT(frame, result, fps , frame_id);
cv::imshow("mot",out_img);
cv::waitKey(30);
frame_id++;
}
capture.release();
cv::destroyAllWindows();
}
void GpuInfer(const std::string& model_dir, const std::string& video_file) {
auto model_file = model_dir + sep + "model.pdmodel";
auto params_file = model_dir + sep + "model.pdiparams";
auto config_file = model_dir + sep + "infer_cfg.yml";
auto option = fastdeploy::RuntimeOption();
option.UseGpu();
auto model = fastdeploy::vision::tracking::PPTracking(
model_file, params_file, config_file, option);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
fastdeploy::vision::MOTResult result;
cv::Mat frame;
int frame_id=0;
cv::VideoCapture capture(video_file);
// according to the time of prediction to calculate fps
float fps= 0.0f;
while (capture.read(frame)) {
if (frame.empty()) {
break;
}
if (!model.Predict(&frame, &result)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
// std::cout << result.Str() << std::endl;
cv::Mat out_img = fastdeploy::vision::VisMOT(frame, result, fps , frame_id);
cv::imshow("mot",out_img);
cv::waitKey(30);
frame_id++;
}
capture.release();
cv::destroyAllWindows();
}
void TrtInfer(const std::string& model_dir, const std::string& video_file) {
auto model_file = model_dir + sep + "model.pdmodel";
auto params_file = model_dir + sep + "model.pdiparams";
auto config_file = model_dir + sep + "infer_cfg.yml";
auto option = fastdeploy::RuntimeOption();
option.UseGpu();
option.UseTrtBackend();
auto model = fastdeploy::vision::tracking::PPTracking(
model_file, params_file, config_file, option);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return;
}
fastdeploy::vision::MOTResult result;
cv::Mat frame;
int frame_id=0;
cv::VideoCapture capture(video_file);
// according to the time of prediction to calculate fps
float fps= 0.0f;
while (capture.read(frame)) {
if (frame.empty()) {
break;
}
if (!model.Predict(&frame, &result)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
// std::cout << result.Str() << std::endl;
cv::Mat out_img = fastdeploy::vision::VisMOT(frame, result, fps , frame_id);
cv::imshow("mot",out_img);
cv::waitKey(30);
frame_id++;
}
capture.release();
cv::destroyAllWindows();
}
int main(int argc, char* argv[]) {
if (argc < 4) {
std::cout
<< "Usage: infer_demo path/to/model_dir path/to/video run_option, "
"e.g ./infer_model ./pptracking_model_dir ./person.mp4 0"
<< std::endl;
std::cout << "The data type of run_option is int, 0: run with cpu; 1: run "
"with gpu; 2: run with gpu and use tensorrt backend."
<< std::endl;
return -1;
}
if (std::atoi(argv[3]) == 0) {
CpuInfer(argv[1], argv[2]);
} else if (std::atoi(argv[3]) == 1) {
GpuInfer(argv[1], argv[2]);
} else if (std::atoi(argv[3]) == 2) {
TrtInfer(argv[1], argv[2]);
}
return 0;
}

View File

@@ -0,0 +1,70 @@
# PP-Tracking Python部署示例
在部署前,需确认以下两个步骤
- 1. 软硬件环境满足要求,参考[FastDeploy环境要求](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
- 2. FastDeploy Python whl包安装参考[FastDeploy Python安装](../../../../../docs/cn/build_and_install/download_prebuilt_libraries.md)
本目录下提供`infer.py`快速完成PP-Tracking在CPU/GPU以及GPU上通过TensorRT加速部署的示例。执行如下脚本即可完成
```bash
#下载部署示例代码
git clone https://github.com/PaddlePaddle/FastDeploy.git
cd FastDeploy/examples/vision/tracking/pptracking/python
# 下载PP-Tracking模型文件和测试视频
wget https://bj.bcebos.com/paddlehub/fastdeploy/fairmot_hrnetv2_w18_dlafpn_30e_576x320.tgz
tar -xvf fairmot_hrnetv2_w18_dlafpn_30e_576x320.tgz
wget https://bj.bcebos.com/paddlehub/fastdeploy/person.mp4
# CPU推理
python infer.py --model fairmot_hrnetv2_w18_dlafpn_30e_576x320 --video person.mp4 --device cpu
# GPU推理
python infer.py --model fairmot_hrnetv2_w18_dlafpn_30e_576x320 --video person.mp4 --device gpu
# GPU上使用TensorRT推理 注意TensorRT推理第一次运行有序列化模型的操作有一定耗时需要耐心等待
python infer.py --model fairmot_hrnetv2_w18_dlafpn_30e_576x320 --video person.mp4 --device gpu --use_trt True
```
## PP-Tracking Python接口
```python
fd.vision.tracking.PPTracking(model_file, params_file, config_file, runtime_option=None, model_format=ModelFormat.PADDLE)
```
PP-Tracking模型加载和初始化其中model_file, params_file以及config_file为训练模型导出的Paddle inference文件具体请参考其文档说明[模型导出](https://github.com/PaddlePaddle/PaddleSeg/tree/release/2.6/Matting)
**参数**
> * **model_file**(str): 模型文件路径
> * **params_file**(str): 参数文件路径
> * **config_file**(str): 推理部署配置文件
> * **runtime_option**(RuntimeOption): 后端推理配置默认为None即采用默认配置
> * **model_format**(ModelFormat): 模型格式默认为Paddle格式
### predict函数
> ```python
> PPTracking.predict(frame)
> ```
>
> 模型预测结口,输入图像直接输出检测结果。
>
> **参数**
>
> > * **frame**(np.ndarray): 输入数据注意需为HWCBGR格式,frame为视频帧如_,frame=cap.read()得到
> **返回**
>
> > 返回`fastdeploy.vision.MOTResult`结构体,结构体说明参考文档[视觉模型预测结果](../../../../../docs/api/vision_results/)
### 类成员属性
#### 预处理参数
用户可按照自己的实际需求,修改下列预处理参数,从而影响最终的推理和部署效果
## 其它文档
- [PP-Tracking 模型介绍](..)
- [PP-Tracking C++部署](../cpp)
- [模型预测结果说明](../../../../../docs/api/vision_results/)
- [如何切换模型推理后端引擎](../../../../../docs/cn/faq/how_to_change_backend.md)

View File

@@ -0,0 +1,79 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
import cv2
import time
import os
def parse_arguments():
import argparse
import ast
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", required=True, help="Path of PaddleSeg model.")
parser.add_argument(
"--video", type=str, required=True, help="Path of test video file.")
parser.add_argument(
"--device",
type=str,
default='cpu',
help="Type of inference device, support 'cpu' or 'gpu'.")
parser.add_argument(
"--use_trt",
type=ast.literal_eval,
default=False,
help="Wether to use tensorrt.")
return parser.parse_args()
def build_option(args):
option = fd.RuntimeOption()
if args.device.lower() == "gpu":
option.use_gpu()
if args.use_trt:
option.use_trt_backend()
return option
args = parse_arguments()
# 配置runtime加载模型
runtime_option = build_option(args)
model_file = os.path.join(args.model, "model.pdmodel")
params_file = os.path.join(args.model, "model.pdiparams")
config_file = os.path.join(args.model, "infer_cfg.yml")
model = fd.vision.tracking.PPTracking(
model_file, params_file, config_file, runtime_option=runtime_option)
# 预测图片分割结果
cap = cv2.VideoCapture(args.video)
frame_id = 0
while True:
start_time = time.time()
frame_id = frame_id+1
_, frame = cap.read()
if frame is None:
break
result = model.predict(frame)
end_time = time.time()
fps = 1.0/(end_time-start_time)
img = fd.vision.vis_mot(frame, result, fps, frame_id)
cv2.imshow("video", img)
cv2.waitKey(30)
cap.release()
cv2.destroyAllWindows()

View File

@@ -48,6 +48,7 @@
#include "fastdeploy/vision/ocr/ppocr/ppocr_v3.h"
#include "fastdeploy/vision/ocr/ppocr/recognizer.h"
#include "fastdeploy/vision/segmentation/ppseg/model.h"
#include "fastdeploy/vision/tracking/pptracking/model.h"
#endif
#include "fastdeploy/vision/visualize/visualize.h"

View File

@@ -148,6 +148,26 @@ void OCRResult::Clear() {
cls_labels.clear();
}
void MOTResult::Clear(){
boxes.clear();
ids.clear();
scores.clear();
class_ids.clear();
}
std::string MOTResult::Str(){
std::string out;
out = "MOTResult:\nall boxes counts: "+std::to_string(boxes.size())+"\n";
out += "[xmin\tymin\txmax\tymax\tid\tscore]\n";
for (size_t i = 0; i < boxes.size(); ++i) {
out = out + "["+ std::to_string(boxes[i][0]) + "\t" +
std::to_string(boxes[i][1]) + "\t" + std::to_string(boxes[i][2]) +
"\t" + std::to_string(boxes[i][3]) + "\t" +
std::to_string(ids[i]) + "\t" + std::to_string(scores[i]) + "]\n";
}
return out;
}
FaceDetectionResult::FaceDetectionResult(const FaceDetectionResult& res) {
boxes.assign(res.boxes.begin(), res.boxes.end());
landmarks.assign(res.landmarks.begin(), res.landmarks.end());

View File

@@ -26,6 +26,7 @@ enum FASTDEPLOY_DECL ResultType {
DETECTION,
SEGMENTATION,
OCR,
MOT,
FACE_DETECTION,
FACE_RECOGNITION,
MATTING,
@@ -154,6 +155,21 @@ struct FASTDEPLOY_DECL OCRResult : public BaseResult {
std::string Str();
};
struct FASTDEPLOY_DECL MOTResult : public BaseResult {
// left top right bottom
std::vector<std::array<int, 4>> boxes;
std::vector<int> ids;
std::vector<float> scores;
std::vector<int> class_ids;
ResultType type = ResultType::MOT;
void Clear();
std::string Str();
};
/*! @brief Face detection result structure for all the face detection models
*/
struct FASTDEPLOY_DECL FaceDetectionResult : public BaseResult {
@@ -268,5 +284,6 @@ struct FASTDEPLOY_DECL MattingResult : public BaseResult {
std::string Str();
};
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,413 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The code is based on:
// https://github.com/gatagat/lap/blob/master/lap/lapjv.cpp
// Ths copyright of gatagat/lap is as follows:
// MIT License
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "fastdeploy/vision/tracking/pptracking/lapjv.h"
namespace fastdeploy {
namespace vision {
namespace tracking {
/** Column-reduction and reduction transfer for a dense cost matrix.
*/
int _ccrrt_dense(
const int n, float *cost[], int *free_rows, int *x, int *y, float *v) {
int n_free_rows;
bool *unique;
for (int i = 0; i < n; i++) {
x[i] = -1;
v[i] = LARGE;
y[i] = 0;
}
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
const float c = cost[i][j];
if (c < v[j]) {
v[j] = c;
y[j] = i;
}
}
}
NEW(unique, bool, n);
memset(unique, TRUE, n);
{
int j = n;
do {
j--;
const int i = y[j];
if (x[i] < 0) {
x[i] = j;
} else {
unique[i] = FALSE;
y[j] = -1;
}
} while (j > 0);
}
n_free_rows = 0;
for (int i = 0; i < n; i++) {
if (x[i] < 0) {
free_rows[n_free_rows++] = i;
} else if (unique[i]) {
const int j = x[i];
float min = LARGE;
for (int j2 = 0; j2 < n; j2++) {
if (j2 == static_cast<int>(j)) {
continue;
}
const float c = cost[i][j2] - v[j2];
if (c < min) {
min = c;
}
}
v[j] -= min;
}
}
FREE(unique);
return n_free_rows;
}
/** Augmenting row reduction for a dense cost matrix.
*/
int _carr_dense(const int n,
float *cost[],
const int n_free_rows,
int *free_rows,
int *x,
int *y,
float *v) {
int current = 0;
int new_free_rows = 0;
int rr_cnt = 0;
while (current < n_free_rows) {
int i0;
int j1, j2;
float v1, v2, v1_new;
bool v1_lowers;
rr_cnt++;
const int free_i = free_rows[current++];
j1 = 0;
v1 = cost[free_i][0] - v[0];
j2 = -1;
v2 = LARGE;
for (int j = 1; j < n; j++) {
const float c = cost[free_i][j] - v[j];
if (c < v2) {
if (c >= v1) {
v2 = c;
j2 = j;
} else {
v2 = v1;
v1 = c;
j2 = j1;
j1 = j;
}
}
}
i0 = y[j1];
v1_new = v[j1] - (v2 - v1);
v1_lowers = v1_new < v[j1];
if (rr_cnt < current * n) {
if (v1_lowers) {
v[j1] = v1_new;
} else if (i0 >= 0 && j2 >= 0) {
j1 = j2;
i0 = y[j2];
}
if (i0 >= 0) {
if (v1_lowers) {
free_rows[--current] = i0;
} else {
free_rows[new_free_rows++] = i0;
}
}
} else {
if (i0 >= 0) {
free_rows[new_free_rows++] = i0;
}
}
x[free_i] = j1;
y[j1] = free_i;
}
return new_free_rows;
}
/** Find columns with minimum d[j] and put them on the SCAN list.
*/
int _find_dense(const int n, int lo, float *d, int *cols, int *y) {
int hi = lo + 1;
float mind = d[cols[lo]];
for (int k = hi; k < n; k++) {
int j = cols[k];
if (d[j] <= mind) {
if (d[j] < mind) {
hi = lo;
mind = d[j];
}
cols[k] = cols[hi];
cols[hi++] = j;
}
}
return hi;
}
// Scan all columns in TODO starting from arbitrary column in SCAN
// and try to decrease d of the TODO columns using the SCAN column.
int _scan_dense(const int n,
float *cost[],
int *plo,
int *phi,
float *d,
int *cols,
int *pred,
int *y,
float *v) {
int lo = *plo;
int hi = *phi;
float h, cred_ij;
while (lo != hi) {
int j = cols[lo++];
const int i = y[j];
const float mind = d[j];
h = cost[i][j] - v[j] - mind;
// For all columns in TODO
for (int k = hi; k < n; k++) {
j = cols[k];
cred_ij = cost[i][j] - v[j] - h;
if (cred_ij < d[j]) {
d[j] = cred_ij;
pred[j] = i;
if (cred_ij == mind) {
if (y[j] < 0) {
return j;
}
cols[k] = cols[hi];
cols[hi++] = j;
}
}
}
}
*plo = lo;
*phi = hi;
return -1;
}
/** Single iteration of modified Dijkstra shortest path algorithm as explained
* in the JV paper.
*
* This is a dense matrix version.
*
* \return The closest free column index.
*/
int find_path_dense(const int n,
float *cost[],
const int start_i,
int *y,
float *v,
int *pred) {
int lo = 0, hi = 0;
int final_j = -1;
int n_ready = 0;
int *cols;
float *d;
NEW(cols, int, n);
NEW(d, float, n);
for (int i = 0; i < n; i++) {
cols[i] = i;
pred[i] = start_i;
d[i] = cost[start_i][i] - v[i];
}
while (final_j == -1) {
// No columns left on the SCAN list.
if (lo == hi) {
n_ready = lo;
hi = _find_dense(n, lo, d, cols, y);
for (int k = lo; k < hi; k++) {
const int j = cols[k];
if (y[j] < 0) {
final_j = j;
}
}
}
if (final_j == -1) {
final_j = _scan_dense(n, cost, &lo, &hi, d, cols, pred, y, v);
}
}
{
const float mind = d[cols[lo]];
for (int k = 0; k < n_ready; k++) {
const int j = cols[k];
v[j] += d[j] - mind;
}
}
FREE(cols);
FREE(d);
return final_j;
}
/** Augment for a dense cost matrix.
*/
int _ca_dense(const int n,
float *cost[],
const int n_free_rows,
int *free_rows,
int *x,
int *y,
float *v) {
int *pred;
NEW(pred, int, n);
for (int *pfree_i = free_rows; pfree_i < free_rows + n_free_rows; pfree_i++) {
int i = -1, j;
int k = 0;
j = find_path_dense(n, cost, *pfree_i, y, v, pred);
while (i != *pfree_i) {
i = pred[j];
y[j] = i;
SWAP_INDICES(j, x[i]);
k++;
}
}
FREE(pred);
return 0;
}
/** Solve dense sparse LAP.
*/
int lapjv_internal(const cv::Mat &cost,
const bool extend_cost,
const float cost_limit,
int *x,
int *y) {
int n_rows = cost.rows;
int n_cols = cost.cols;
int n;
if (n_rows == n_cols) {
n = n_rows;
} else if (!extend_cost) {
throw std::invalid_argument(
"Square cost array expected. If cost is intentionally non-square, pass "
"extend_cost=True.");
}
// Get extend cost
if (extend_cost || cost_limit < LARGE) {
n = n_rows + n_cols;
}
cv::Mat cost_expand(n, n, CV_32F);
float expand_value;
if (cost_limit < LARGE) {
expand_value = cost_limit / 2;
} else {
double max_v;
minMaxLoc(cost, nullptr, &max_v);
expand_value = static_cast<float>(max_v) + 1.;
}
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
cost_expand.at<float>(i, j) = expand_value;
if (i >= n_rows && j >= n_cols) {
cost_expand.at<float>(i, j) = 0;
} else if (i < n_rows && j < n_cols) {
cost_expand.at<float>(i, j) = cost.at<float>(i, j);
}
}
}
// Convert Mat to pointer array
float **cost_ptr;
NEW(cost_ptr, float *, n);
for (int i = 0; i < n; ++i) {
NEW(cost_ptr[i], float, n);
}
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
cost_ptr[i][j] = cost_expand.at<float>(i, j);
}
}
int ret;
int *free_rows;
float *v;
int *x_c;
int *y_c;
NEW(free_rows, int, n);
NEW(v, float, n);
NEW(x_c, int, n);
NEW(y_c, int, n);
ret = _ccrrt_dense(n, cost_ptr, free_rows, x_c, y_c, v);
int i = 0;
while (ret > 0 && i < 2) {
ret = _carr_dense(n, cost_ptr, ret, free_rows, x_c, y_c, v);
i++;
}
if (ret > 0) {
ret = _ca_dense(n, cost_ptr, ret, free_rows, x_c, y_c, v);
}
FREE(v);
FREE(free_rows);
for (int i = 0; i < n; ++i) {
FREE(cost_ptr[i]);
}
FREE(cost_ptr);
if (ret != 0) {
if (ret == -1) {
throw "Out of memory.";
}
throw "Unknown error (lapjv_internal)";
}
// Get output of x, y, opt
for (int i = 0; i < n; ++i) {
if (i < n_rows) {
x[i] = x_c[i];
if (x[i] >= n_cols) {
x[i] = -1;
}
}
if (i < n_cols) {
y[i] = y_c[i];
if (y[i] >= n_rows) {
y[i] = -1;
}
}
}
FREE(x_c);
FREE(y_c);
return ret;
}
} // namespace tracking
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,66 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The code is based on:
// https://github.com/gatagat/lap/blob/master/lap/lapjv.h
// Ths copyright of gatagat/lap is as follows:
// MIT License
#pragma once
#define LARGE 1000000
#if !defined TRUE
#define TRUE 1
#endif
#if !defined FALSE
#define FALSE 0
#endif
#define NEW(x, t, n) \
if ((x = reinterpret_cast<t *>(malloc(sizeof(t) * (n)))) == 0) { \
return -1; \
}
#define FREE(x) \
if (x != 0) { \
free(x); \
x = 0; \
}
#define SWAP_INDICES(a, b) \
{ \
int_t _temp_index = a; \
a = b; \
b = _temp_index; \
}
#include <opencv2/opencv.hpp>
namespace fastdeploy {
namespace vision {
namespace tracking {
typedef signed int int_t;
typedef unsigned int uint_t;
typedef double cost_t;
typedef char boolean;
typedef enum fp_t { FP_1 = 1, FP_2 = 2, FP_DYNAMIC = 3 } fp_t;
int lapjv_internal(const cv::Mat &cost,
const bool extend_cost,
const float cost_limit,
int *x,
int *y);
} // namespace tracking
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,59 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/tracking/pptracking/letter_box.h"
namespace fastdeploy {
namespace vision {
namespace tracking {
LetterBoxResize::LetterBoxResize(const std::vector<int>& target_size, const std::vector<float>& color){
target_size_=target_size;
color_=color;
}
bool LetterBoxResize::ImplByOpenCV(Mat* mat){
if (mat->Channels() != color_.size()) {
FDERROR << "Pad: Require input channels equals to size of padding value, "
"but now channels = "
<< mat->Channels()
<< ", the size of padding values = " << color_.size() << "."
<< std::endl;
return false;
}
// generate scale_factor
int origin_w = mat->Width();
int origin_h = mat->Height();
int target_h = target_size_[0];
int target_w = target_size_[1];
float ratio_h = static_cast<float>(target_h) / static_cast<float>(origin_h);
float ratio_w = static_cast<float>(target_w) / static_cast<float>(origin_w);
float resize_scale = std::min(ratio_h, ratio_w);
int new_shape_w = std::round(mat->Width() * resize_scale);
int new_shape_h = std::round(mat->Height() * resize_scale);
float padw = (target_size_[1] - new_shape_w) / 2.;
float padh = (target_size_[0] - new_shape_h) / 2.;
int top = std::round(padh - 0.1);
int bottom = std::round(padh + 0.1);
int left = std::round(padw - 0.1);
int right = std::round(padw + 0.1);
Resize::Run(mat,new_shape_w,new_shape_h);
Pad::Run(mat,top,bottom,left,right,color_);
return true;
}
} // namespace tracking
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,36 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/fastdeploy_model.h"
namespace fastdeploy {
namespace vision {
namespace tracking {
class LetterBoxResize: public Processor{
public:
LetterBoxResize(const std::vector<int>& target_size, const std::vector<float>& color);
bool ImplByOpenCV(Mat* mat) override;
std::string Name() override { return "LetterBoxResize"; }
private:
std::vector<int> target_size_;
std::vector<float> color_;
};
} // namespace tracking
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,331 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/vision/tracking/pptracking/model.h"
#include "yaml-cpp/yaml.h"
#include "paddle2onnx/converter.h"
namespace fastdeploy {
namespace vision {
namespace tracking {
PPTracking::PPTracking(const std::string& model_file,
const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option,
const ModelFormat& model_format){
config_file_=config_file;
valid_cpu_backends = {Backend::PDINFER, Backend::ORT};
valid_gpu_backends = {Backend::PDINFER, Backend::ORT};
runtime_option = custom_option;
runtime_option.model_format = model_format;
runtime_option.model_file = model_file;
runtime_option.params_file = params_file;
initialized = Initialize();
}
bool PPTracking::BuildPreprocessPipelineFromConfig(){
processors_.clear();
YAML::Node cfg;
try {
cfg = YAML::LoadFile(config_file_);
} catch (YAML::BadFile& e) {
FDERROR << "Failed to load yaml file " << config_file_
<< ", maybe you should check this file." << std::endl;
return false;
}
// Get draw_threshold for visualization
if (cfg["draw_threshold"].IsDefined()) {
draw_threshold_ = cfg["draw_threshold"].as<float>();
} else {
FDERROR << "Please set draw_threshold." << std::endl;
return false;
}
// Get config for tracker
if (cfg["tracker"].IsDefined()) {
if (cfg["tracker"]["conf_thres"].IsDefined()) {
conf_thresh_ = cfg["tracker"]["conf_thres"].as<float>();
}
else {
std::cerr << "Please set conf_thres in tracker." << std::endl;
return false;
}
if (cfg["tracker"]["min_box_area"].IsDefined()) {
min_box_area_ = cfg["tracker"]["min_box_area"].as<float>();
}
if (cfg["tracker"]["tracked_thresh"].IsDefined()) {
tracked_thresh_ = cfg["tracker"]["tracked_thresh"].as<float>();
}
}
processors_.push_back(std::make_shared<BGR2RGB>());
for (const auto& op : cfg["Preprocess"]) {
std::string op_name = op["type"].as<std::string>();
if (op_name == "Resize") {
bool keep_ratio = op["keep_ratio"].as<bool>();
auto target_size = op["target_size"].as<std::vector<int>>();
int interp = op["interp"].as<int>();
FDASSERT(target_size.size() == 2,
"Require size of target_size be 2, but now it's %lu.",
target_size.size());
if (!keep_ratio) {
int width = target_size[1];
int height = target_size[0];
processors_.push_back(
std::make_shared<Resize>(width, height, -1.0, -1.0, interp, false));
} else {
int min_target_size = std::min(target_size[0], target_size[1]);
int max_target_size = std::max(target_size[0], target_size[1]);
std::vector<int> max_size;
if (max_target_size > 0) {
max_size.push_back(max_target_size);
max_size.push_back(max_target_size);
}
processors_.push_back(std::make_shared<ResizeByShort>(
min_target_size, interp, true, max_size));
}
}
else if(op_name == "LetterBoxResize"){
auto target_size = op["target_size"].as<std::vector<int>>();
FDASSERT(target_size.size() == 2,"Require size of target_size be 2, but now it's %lu.",
target_size.size());
std::vector<float> color{127.0f,127.0f,127.0f};
if (op["fill_value"].IsDefined()){
color =op["fill_value"].as<std::vector<float>>();
}
processors_.push_back(std::make_shared<LetterBoxResize>(target_size, color));
}
else if (op_name == "NormalizeImage") {
auto mean = op["mean"].as<std::vector<float>>();
auto std = op["std"].as<std::vector<float>>();
bool is_scale = true;
if (op["is_scale"]) {
is_scale = op["is_scale"].as<bool>();
}
std::string norm_type = "mean_std";
if (op["norm_type"]) {
norm_type = op["norm_type"].as<std::string>();
}
if (norm_type != "mean_std") {
std::fill(mean.begin(), mean.end(), 0.0);
std::fill(std.begin(), std.end(), 1.0);
}
processors_.push_back(std::make_shared<Normalize>(mean, std, is_scale));
}
else if (op_name == "Permute") {
// Do nothing, do permute as the last operation
continue;
// processors_.push_back(std::make_shared<HWC2CHW>());
} else if (op_name == "Pad") {
auto size = op["size"].as<std::vector<int>>();
auto value = op["fill_value"].as<std::vector<float>>();
processors_.push_back(std::make_shared<Cast>("float"));
processors_.push_back(
std::make_shared<PadToSize>(size[1], size[0], value));
} else if (op_name == "PadStride") {
auto stride = op["stride"].as<int>();
processors_.push_back(
std::make_shared<StridePad>(stride, std::vector<float>(3, 0)));
} else {
FDERROR << "Unexcepted preprocess operator: " << op_name << "."
<< std::endl;
return false;
}
}
processors_.push_back(std::make_shared<HWC2CHW>());
return true;
}
void PPTracking::GetNmsInfo() {
if (runtime_option.model_format == ModelFormat::PADDLE) {
std::string contents;
if (!ReadBinaryFromFile(runtime_option.model_file, &contents)) {
return;
}
auto reader = paddle2onnx::PaddleReader(contents.c_str(), contents.size());
if (reader.has_nms) {
has_nms_ = true;
background_label = reader.nms_params.background_label;
keep_top_k = reader.nms_params.keep_top_k;
nms_eta = reader.nms_params.nms_eta;
nms_threshold = reader.nms_params.nms_threshold;
score_threshold = reader.nms_params.score_threshold;
nms_top_k = reader.nms_params.nms_top_k;
normalized = reader.nms_params.normalized;
}
}
}
bool PPTracking::Initialize() {
// remove multiclass_nms3 now
// this is a trick operation for ppyoloe while inference on trt
GetNmsInfo();
runtime_option.remove_multiclass_nms_ = true;
runtime_option.custom_op_info_["multiclass_nms3"] = "MultiClassNMS";
if (!BuildPreprocessPipelineFromConfig()) {
FDERROR << "Failed to build preprocess pipeline from configuration file."
<< std::endl;
return false;
}
if (!InitRuntime()) {
FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
return false;
}
if (has_nms_ && runtime_option.backend == Backend::TRT) {
FDINFO << "Detected operator multiclass_nms3 in your model, will replace "
"it with fastdeploy::backend::MultiClassNMS(background_label="
<< background_label << ", keep_top_k=" << keep_top_k
<< ", nms_eta=" << nms_eta << ", nms_threshold=" << nms_threshold
<< ", score_threshold=" << score_threshold
<< ", nms_top_k=" << nms_top_k << ", normalized=" << normalized
<< ")." << std::endl;
has_nms_ = false;
}
// create JDETracker instance
std::unique_ptr<JDETracker> jdeTracker(new JDETracker);
jdeTracker_ = std::move(jdeTracker);
return true;
}
bool PPTracking::Predict(cv::Mat *img, MOTResult *result) {
Mat mat(*img);
std::vector<FDTensor> input_tensors;
if (!Preprocess(&mat, &input_tensors)) {
FDERROR << "Failed to preprocess input image." << std::endl;
return false;
}
std::vector<FDTensor> output_tensors;
if (!Infer(input_tensors, &output_tensors)) {
FDERROR << "Failed to inference." << std::endl;
return false;
}
if (!Postprocess(output_tensors, result)) {
FDERROR << "Failed to post process." << std::endl;
return false;
}
return true;
}
bool PPTracking::Preprocess(Mat* mat, std::vector<FDTensor>* outputs) {
int origin_w = mat->Width();
int origin_h = mat->Height();
for (size_t i = 0; i < processors_.size(); ++i) {
if (!(*(processors_[i].get()))(mat)) {
FDERROR << "Failed to process image data in " << processors_[i]->Name()
<< "." << std::endl;
return false;
}
}
// LetterBoxResize(mat);
// Normalize::Run(mat,mean_,scale_,is_scale_);
// HWC2CHW::Run(mat);
Cast::Run(mat, "float");
outputs->resize(3);
// image_shape
(*outputs)[0].Allocate({1, 2}, FDDataType::FP32, InputInfoOfRuntime(0).name);
float* shape = static_cast<float*>((*outputs)[0].MutableData());
shape[0] = mat->Height();
shape[1] = mat->Width();
// image
(*outputs)[1].name = InputInfoOfRuntime(1).name;
mat->ShareWithTensor(&((*outputs)[1]));
(*outputs)[1].ExpandDim(0);
// scale
(*outputs)[2].Allocate({1, 2}, FDDataType::FP32, InputInfoOfRuntime(2).name);
float* scale = static_cast<float*>((*outputs)[2].MutableData());
scale[0] = mat->Height() * 1.0 / origin_h;
scale[1] = mat->Width() * 1.0 / origin_w;
return true;
}
void FilterDets(const float conf_thresh,const cv::Mat& dets,std::vector<int>* index) {
for (int i = 0; i < dets.rows; ++i) {
float score = *dets.ptr<float>(i, 4);
if (score > conf_thresh) {
index->push_back(i);
}
}
}
bool PPTracking::Postprocess(std::vector<FDTensor>& infer_result, MOTResult *result){
auto bbox_shape = infer_result[0].shape;
auto bbox_data = static_cast<float*>(infer_result[0].Data());
auto emb_shape = infer_result[1].shape;
auto emb_data = static_cast<float*>(infer_result[1].Data());
cv::Mat dets(bbox_shape[0], 6, CV_32FC1, bbox_data);
cv::Mat emb(bbox_shape[0], emb_shape[1], CV_32FC1, emb_data);
result->Clear();
std::vector<Track> tracks;
std::vector<int> valid;
FilterDets(conf_thresh_, dets, &valid);
cv::Mat new_dets, new_emb;
for (int i = 0; i < valid.size(); ++i) {
new_dets.push_back(dets.row(valid[i]));
new_emb.push_back(emb.row(valid[i]));
}
jdeTracker_->update(new_dets, new_emb, &tracks);
if (tracks.size() == 0) {
std::array<int ,4> box={int(*dets.ptr<float>(0, 0)),
int(*dets.ptr<float>(0, 1)),
int(*dets.ptr<float>(0, 2)),
int(*dets.ptr<float>(0, 3))};
result->boxes.push_back(box);
result->ids.push_back(1);
result->scores.push_back(*dets.ptr<float>(0, 4));
} else {
std::vector<Track>::iterator titer;
for (titer = tracks.begin(); titer != tracks.end(); ++titer) {
if (titer->score < tracked_thresh_) {
continue;
} else {
float w = titer->ltrb[2] - titer->ltrb[0];
float h = titer->ltrb[3] - titer->ltrb[1];
bool vertical = w / h > 1.6;
float area = w * h;
if (area > min_box_area_ && !vertical) {
std::array<int ,4> box = {
int(titer->ltrb[0]), int(titer->ltrb[1]), int(titer->ltrb[2]), int(titer->ltrb[3])};
result->boxes.push_back(box);
result->ids.push_back(titer->id);
result->scores.push_back(titer->score);
}
}
}
}
return true;
}
} // namespace tracking
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,90 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/vision/common/processors/transform.h"
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/common/result.h"
#include "fastdeploy/vision/tracking/pptracking/tracker.h"
#include "fastdeploy/vision/tracking/pptracking/letter_box.h"
namespace fastdeploy {
namespace vision {
namespace tracking {
class FASTDEPLOY_DECL PPTracking: public FastDeployModel {
public:
/** \brief Set path of model file and configuration file, and the configuration of runtime
*
* \param[in] model_file Path of model file, e.g pptracking/model.pdmodel
* \param[in] params_file Path of parameter file, e.g pptracking/model.pdiparams, if the model format is ONNX, this parameter will be ignored
* \param[in] config_file Path of configuration file for deployment, e.g pptracking/infer_cfg.yml
* \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends`
* \param[in] model_format Model format of the loaded model, default is Paddle format
*/
PPTracking(const std::string& model_file,
const std::string& params_file,
const std::string& config_file,
const RuntimeOption& custom_option = RuntimeOption(),
const ModelFormat& model_format = ModelFormat::PADDLE);
/// Get model's name
std::string ModelName() const override { return "pptracking"; }
/** \brief Predict the detection result for an input image(consecutive)
*
* \param[in] im The input image data which is consecutive frame, comes from imread() or videoCapture.read()
* \param[in] result The output tracking result will be writen to this structure
* \return true if the prediction successed, otherwise false
*/
virtual bool Predict(cv::Mat* img, MOTResult* result);
private:
bool BuildPreprocessPipelineFromConfig();
bool Initialize();
void GetNmsInfo();
bool Preprocess(Mat* img, std::vector<FDTensor>* outputs);
bool Postprocess(std::vector<FDTensor>& infer_result, MOTResult *result);
std::vector<std::shared_ptr<Processor>> processors_;
std::string config_file_;
float draw_threshold_;
float conf_thresh_;
float tracked_thresh_;
float min_box_area_;
bool is_scale_ = true;
std::unique_ptr<JDETracker> jdeTracker_;
// configuration for nms
int64_t background_label = -1;
int64_t keep_top_k = 300;
float nms_eta = 1.0;
float nms_threshold = 0.7;
float score_threshold = 0.01;
int64_t nms_top_k = 10000;
bool normalized = true;
bool has_nms_ = true;
};
} // namespace tracking
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,31 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindPPTracking(pybind11::module &m) {
pybind11::class_<vision::tracking::PPTracking, FastDeployModel>(
m, "PPTracking")
.def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
ModelFormat>())
.def("predict",
[](vision::tracking::PPTracking &self,
pybind11::array &data) {
auto mat = PyArrayToCvMat(data);
vision::MOTResult *res = new vision::MOTResult();
self.Predict(&mat, res);
return res;
});
}
} // namespace fastdeploy

View File

@@ -0,0 +1,305 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The code is based on:
// https://github.com/CnybTseng/JDE/blob/master/platforms/common/jdetracker.cpp
// Ths copyright of CnybTseng/JDE is as follows:
// MIT License
#include <limits.h>
#include <stdio.h>
#include <algorithm>
#include <map>
#include "fastdeploy/vision/tracking/pptracking/lapjv.h"
#include "fastdeploy/vision/tracking/pptracking/tracker.h"
#define mat2vec4f(m) \
cv::Vec4f(*m.ptr<float>(0, 0), \
*m.ptr<float>(0, 1), \
*m.ptr<float>(0, 2), \
*m.ptr<float>(0, 3))
namespace fastdeploy {
namespace vision {
namespace tracking {
static std::map<int, float> chi2inv95 = {{1, 3.841459f},
{2, 5.991465f},
{3, 7.814728f},
{4, 9.487729f},
{5, 11.070498f},
{6, 12.591587f},
{7, 14.067140f},
{8, 15.507313f},
{9, 16.918978f}};
JDETracker::JDETracker()
: timestamp(0), max_lost_time(30), lambda(0.98f), det_thresh(0.3f) {}
bool JDETracker::update(const cv::Mat &dets,
const cv::Mat &emb,
std::vector<Track> *tracks) {
++timestamp;
TrajectoryPool candidates(dets.rows);
for (int i = 0; i < dets.rows; ++i) {
float score = *dets.ptr<float>(i, 1);
const cv::Mat &ltrb_ = dets(cv::Rect(2, i, 4, 1));
cv::Vec4f ltrb = mat2vec4f(ltrb_);
const cv::Mat &embedding = emb(cv::Rect(0, i, emb.cols, 1));
candidates[i] = Trajectory(ltrb, score, embedding);
}
TrajectoryPtrPool tracked_trajectories;
TrajectoryPtrPool unconfirmed_trajectories;
for (size_t i = 0; i < this->tracked_trajectories.size(); ++i) {
if (this->tracked_trajectories[i].is_activated)
tracked_trajectories.push_back(&this->tracked_trajectories[i]);
else
unconfirmed_trajectories.push_back(&this->tracked_trajectories[i]);
}
TrajectoryPtrPool trajectory_pool =
tracked_trajectories + &(this->lost_trajectories);
for (size_t i = 0; i < trajectory_pool.size(); ++i)
trajectory_pool[i]->predict();
Match matches;
std::vector<int> mismatch_row;
std::vector<int> mismatch_col;
cv::Mat cost = motion_distance(trajectory_pool, candidates);
linear_assignment(cost, 0.7f, &matches, &mismatch_row, &mismatch_col);
MatchIterator miter;
TrajectoryPtrPool activated_trajectories;
TrajectoryPtrPool retrieved_trajectories;
for (miter = matches.begin(); miter != matches.end(); miter++) {
Trajectory *pt = trajectory_pool[miter->first];
Trajectory &ct = candidates[miter->second];
if (pt->state == Tracked) {
pt->update(&ct, timestamp);
activated_trajectories.push_back(pt);
} else {
pt->reactivate(&ct, count,timestamp);
retrieved_trajectories.push_back(pt);
}
}
TrajectoryPtrPool next_candidates(mismatch_col.size());
for (size_t i = 0; i < mismatch_col.size(); ++i)
next_candidates[i] = &candidates[mismatch_col[i]];
TrajectoryPtrPool next_trajectory_pool;
for (size_t i = 0; i < mismatch_row.size(); ++i) {
int j = mismatch_row[i];
if (trajectory_pool[j]->state == Tracked)
next_trajectory_pool.push_back(trajectory_pool[j]);
}
cost = iou_distance(next_trajectory_pool, next_candidates);
linear_assignment(cost, 0.5f, &matches, &mismatch_row, &mismatch_col);
for (miter = matches.begin(); miter != matches.end(); miter++) {
Trajectory *pt = next_trajectory_pool[miter->first];
Trajectory *ct = next_candidates[miter->second];
if (pt->state == Tracked) {
pt->update(ct, timestamp);
activated_trajectories.push_back(pt);
} else {
pt->reactivate(ct,count, timestamp);
retrieved_trajectories.push_back(pt);
}
}
TrajectoryPtrPool lost_trajectories;
for (size_t i = 0; i < mismatch_row.size(); ++i) {
Trajectory *pt = next_trajectory_pool[mismatch_row[i]];
if (pt->state != Lost) {
pt->mark_lost();
lost_trajectories.push_back(pt);
}
}
TrajectoryPtrPool nnext_candidates(mismatch_col.size());
for (size_t i = 0; i < mismatch_col.size(); ++i)
nnext_candidates[i] = next_candidates[mismatch_col[i]];
cost = iou_distance(unconfirmed_trajectories, nnext_candidates);
linear_assignment(cost, 0.7f, &matches, &mismatch_row, &mismatch_col);
for (miter = matches.begin(); miter != matches.end(); miter++) {
unconfirmed_trajectories[miter->first]->update(
nnext_candidates[miter->second], timestamp);
activated_trajectories.push_back(unconfirmed_trajectories[miter->first]);
}
TrajectoryPtrPool removed_trajectories;
for (size_t i = 0; i < mismatch_row.size(); ++i) {
unconfirmed_trajectories[mismatch_row[i]]->mark_removed();
removed_trajectories.push_back(unconfirmed_trajectories[mismatch_row[i]]);
}
for (size_t i = 0; i < mismatch_col.size(); ++i) {
if (nnext_candidates[mismatch_col[i]]->score < det_thresh) continue;
nnext_candidates[mismatch_col[i]]->activate(count, timestamp);
activated_trajectories.push_back(nnext_candidates[mismatch_col[i]]);
}
for (size_t i = 0; i < this->lost_trajectories.size(); ++i) {
Trajectory &lt = this->lost_trajectories[i];
if (timestamp - lt.timestamp > max_lost_time) {
lt.mark_removed();
removed_trajectories.push_back(&lt);
}
}
TrajectoryPoolIterator piter;
for (piter = this->tracked_trajectories.begin();
piter != this->tracked_trajectories.end();) {
if (piter->state != Tracked)
piter = this->tracked_trajectories.erase(piter);
else
++piter;
}
this->tracked_trajectories += activated_trajectories;
this->tracked_trajectories += retrieved_trajectories;
this->lost_trajectories -= this->tracked_trajectories;
this->lost_trajectories += lost_trajectories;
this->lost_trajectories -= this->removed_trajectories;
this->removed_trajectories += removed_trajectories;
remove_duplicate_trajectory(&this->tracked_trajectories,
&this->lost_trajectories);
tracks->clear();
for (size_t i = 0; i < this->tracked_trajectories.size(); ++i) {
if (this->tracked_trajectories[i].is_activated) {
Track track = {this->tracked_trajectories[i].id,
this->tracked_trajectories[i].score,
this->tracked_trajectories[i].ltrb};
tracks->push_back(track);
}
}
return 0;
}
cv::Mat JDETracker::motion_distance(const TrajectoryPtrPool &a,
const TrajectoryPool &b) {
if (0 == a.size() || 0 == b.size())
return cv::Mat(a.size(), b.size(), CV_32F);
cv::Mat edists = embedding_distance(a, b);
cv::Mat mdists = mahalanobis_distance(a, b);
cv::Mat fdists = lambda * edists + (1 - lambda) * mdists;
const float gate_thresh = chi2inv95[4];
for (int i = 0; i < fdists.rows; ++i) {
for (int j = 0; j < fdists.cols; ++j) {
if (*mdists.ptr<float>(i, j) > gate_thresh)
*fdists.ptr<float>(i, j) = FLT_MAX;
}
}
return fdists;
}
void JDETracker::linear_assignment(const cv::Mat &cost,
float cost_limit,
Match *matches,
std::vector<int> *mismatch_row,
std::vector<int> *mismatch_col) {
matches->clear();
mismatch_row->clear();
mismatch_col->clear();
if (cost.empty()) {
for (int i = 0; i < cost.rows; ++i) mismatch_row->push_back(i);
for (int i = 0; i < cost.cols; ++i) mismatch_col->push_back(i);
return;
}
float opt = 0;
cv::Mat x(cost.rows, 1, CV_32S);
cv::Mat y(cost.cols, 1, CV_32S);
lapjv_internal(cost,
true,
cost_limit,
reinterpret_cast<int *>(x.data),
reinterpret_cast<int *>(y.data));
for (int i = 0; i < x.rows; ++i) {
int j = *x.ptr<int>(i);
if (j >= 0)
matches->insert({i, j});
else
mismatch_row->push_back(i);
}
for (int i = 0; i < y.rows; ++i) {
int j = *y.ptr<int>(i);
if (j < 0) mismatch_col->push_back(i);
}
return;
}
void JDETracker::remove_duplicate_trajectory(TrajectoryPool *a,
TrajectoryPool *b,
float iou_thresh) {
if (a->size() == 0 || b->size() == 0) return;
cv::Mat dist = iou_distance(*a, *b);
cv::Mat mask = dist < iou_thresh;
std::vector<cv::Point> idx;
cv::findNonZero(mask, idx);
std::vector<int> da;
std::vector<int> db;
for (size_t i = 0; i < idx.size(); ++i) {
int ta = (*a)[idx[i].y].timestamp - (*a)[idx[i].y].starttime;
int tb = (*b)[idx[i].x].timestamp - (*b)[idx[i].x].starttime;
if (ta > tb)
db.push_back(idx[i].x);
else
da.push_back(idx[i].y);
}
int id = 0;
TrajectoryPoolIterator piter;
for (piter = a->begin(); piter != a->end();) {
std::vector<int>::iterator iter = find(da.begin(), da.end(), id++);
if (iter != da.end())
piter = a->erase(piter);
else
++piter;
}
id = 0;
for (piter = b->begin(); piter != b->end();) {
std::vector<int>::iterator iter = find(db.begin(), db.end(), id++);
if (iter != db.end())
piter = b->erase(piter);
else
++piter;
}
}
} // namespace tracking
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,78 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The code is based on:
// https://github.com/CnybTseng/JDE/blob/master/platforms/common/jdetracker.h
// Ths copyright of CnybTseng/JDE is as follows:
// MIT License
#pragma once
#include <map>
#include <vector>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include "fastdeploy/fastdeploy_model.h"
#include "fastdeploy/vision/tracking/pptracking/trajectory.h"
namespace fastdeploy {
namespace vision {
namespace tracking {
typedef std::map<int, int> Match;
typedef std::map<int, int>::iterator MatchIterator;
struct Track {
int id;
float score;
cv::Vec4f ltrb;
};
class FASTDEPLOY_DECL JDETracker {
public:
JDETracker();
virtual bool update(const cv::Mat &dets,
const cv::Mat &emb,
std::vector<Track> *tracks);
virtual ~JDETracker() {}
private:
cv::Mat motion_distance(const TrajectoryPtrPool &a, const TrajectoryPool &b);
void linear_assignment(const cv::Mat &cost,
float cost_limit,
Match *matches,
std::vector<int> *mismatch_row,
std::vector<int> *mismatch_col);
void remove_duplicate_trajectory(TrajectoryPool *a,
TrajectoryPool *b,
float iou_thresh = 0.15f);
private:
int timestamp;
TrajectoryPool tracked_trajectories;
TrajectoryPool lost_trajectories;
TrajectoryPool removed_trajectories;
int max_lost_time;
float lambda;
float det_thresh;
int count = 0;
};
} // namespace tracking
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,519 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The code is based on:
// https://github.com/CnybTseng/JDE/blob/master/platforms/common/trajectory.cpp
// Ths copyright of CnybTseng/JDE is as follows:
// MIT License
#include "fastdeploy/vision/tracking/pptracking/trajectory.h"
#include <algorithm>
namespace fastdeploy {
namespace vision {
namespace tracking {
void TKalmanFilter::init(const cv::Mat &measurement) {
measurement.copyTo(statePost(cv::Rect(0, 0, 1, 4)));
statePost(cv::Rect(0, 4, 1, 4)).setTo(0);
statePost.copyTo(statePre);
float varpos = 2 * std_weight_position * (*measurement.ptr<float>(3));
varpos *= varpos;
float varvel = 10 * std_weight_velocity * (*measurement.ptr<float>(3));
varvel *= varvel;
errorCovPost.setTo(0);
*errorCovPost.ptr<float>(0, 0) = varpos;
*errorCovPost.ptr<float>(1, 1) = varpos;
*errorCovPost.ptr<float>(2, 2) = 1e-4f;
*errorCovPost.ptr<float>(3, 3) = varpos;
*errorCovPost.ptr<float>(4, 4) = varvel;
*errorCovPost.ptr<float>(5, 5) = varvel;
*errorCovPost.ptr<float>(6, 6) = 1e-10f;
*errorCovPost.ptr<float>(7, 7) = varvel;
errorCovPost.copyTo(errorCovPre);
}
const cv::Mat &TKalmanFilter::predict() {
float varpos = std_weight_position * (*statePre.ptr<float>(3));
varpos *= varpos;
float varvel = std_weight_velocity * (*statePre.ptr<float>(3));
varvel *= varvel;
processNoiseCov.setTo(0);
*processNoiseCov.ptr<float>(0, 0) = varpos;
*processNoiseCov.ptr<float>(1, 1) = varpos;
*processNoiseCov.ptr<float>(2, 2) = 1e-4f;
*processNoiseCov.ptr<float>(3, 3) = varpos;
*processNoiseCov.ptr<float>(4, 4) = varvel;
*processNoiseCov.ptr<float>(5, 5) = varvel;
*processNoiseCov.ptr<float>(6, 6) = 1e-10f;
*processNoiseCov.ptr<float>(7, 7) = varvel;
return cv::KalmanFilter::predict();
}
const cv::Mat &TKalmanFilter::correct(const cv::Mat &measurement) {
float varpos = std_weight_position * (*measurement.ptr<float>(3));
varpos *= varpos;
measurementNoiseCov.setTo(0);
*measurementNoiseCov.ptr<float>(0, 0) = varpos;
*measurementNoiseCov.ptr<float>(1, 1) = varpos;
*measurementNoiseCov.ptr<float>(2, 2) = 1e-2f;
*measurementNoiseCov.ptr<float>(3, 3) = varpos;
return cv::KalmanFilter::correct(measurement);
}
void TKalmanFilter::project(cv::Mat *mean, cv::Mat *covariance) const {
float varpos = std_weight_position * (*statePost.ptr<float>(3));
varpos *= varpos;
cv::Mat measurementNoiseCov_ = cv::Mat::eye(4, 4, CV_32F);
*measurementNoiseCov_.ptr<float>(0, 0) = varpos;
*measurementNoiseCov_.ptr<float>(1, 1) = varpos;
*measurementNoiseCov_.ptr<float>(2, 2) = 1e-2f;
*measurementNoiseCov_.ptr<float>(3, 3) = varpos;
*mean = measurementMatrix * statePost;
cv::Mat temp = measurementMatrix * errorCovPost;
gemm(temp,
measurementMatrix,
1,
measurementNoiseCov_,
1,
*covariance,
cv::GEMM_2_T);
}
const cv::Mat &Trajectory::predict(void) {
if (state != Tracked) *cv::KalmanFilter::statePost.ptr<float>(7) = 0;
return TKalmanFilter::predict();
}
void Trajectory::update(Trajectory *traj,
int timestamp_,
bool update_embedding_) {
timestamp = timestamp_;
++length;
ltrb = traj->ltrb;
xyah = traj->xyah;
TKalmanFilter::correct(cv::Mat(traj->xyah));
state = Tracked;
is_activated = true;
score = traj->score;
if (update_embedding_) update_embedding(traj->current_embedding);
}
void Trajectory::activate(int &cnt,int timestamp_) {
id = next_id(cnt);
TKalmanFilter::init(cv::Mat(xyah));
length = 0;
state = Tracked;
if (timestamp_ == 1) {
is_activated = true;
}
timestamp = timestamp_;
starttime = timestamp_;
}
void Trajectory::reactivate(Trajectory *traj,int &cnt, int timestamp_, bool newid) {
TKalmanFilter::correct(cv::Mat(traj->xyah));
update_embedding(traj->current_embedding);
length = 0;
state = Tracked;
is_activated = true;
timestamp = timestamp_;
if (newid) id = next_id(cnt);
}
void Trajectory::update_embedding(const cv::Mat &embedding) {
current_embedding = embedding / cv::norm(embedding);
if (smooth_embedding.empty()) {
smooth_embedding = current_embedding;
} else {
smooth_embedding = eta * smooth_embedding + (1 - eta) * current_embedding;
}
smooth_embedding = smooth_embedding / cv::norm(smooth_embedding);
}
TrajectoryPool operator+(const TrajectoryPool &a, const TrajectoryPool &b) {
TrajectoryPool sum;
sum.insert(sum.end(), a.begin(), a.end());
std::vector<int> ids(a.size());
for (size_t i = 0; i < a.size(); ++i) ids[i] = a[i].id;
for (size_t i = 0; i < b.size(); ++i) {
std::vector<int>::iterator iter = find(ids.begin(), ids.end(), b[i].id);
if (iter == ids.end()) {
sum.push_back(b[i]);
ids.push_back(b[i].id);
}
}
return sum;
}
TrajectoryPool operator+(const TrajectoryPool &a, const TrajectoryPtrPool &b) {
TrajectoryPool sum;
sum.insert(sum.end(), a.begin(), a.end());
std::vector<int> ids(a.size());
for (size_t i = 0; i < a.size(); ++i) ids[i] = a[i].id;
for (size_t i = 0; i < b.size(); ++i) {
std::vector<int>::iterator iter = find(ids.begin(), ids.end(), b[i]->id);
if (iter == ids.end()) {
sum.push_back(*b[i]);
ids.push_back(b[i]->id);
}
}
return sum;
}
TrajectoryPool &operator+=(TrajectoryPool &a, // NOLINT
const TrajectoryPtrPool &b) {
std::vector<int> ids(a.size());
for (size_t i = 0; i < a.size(); ++i) ids[i] = a[i].id;
for (size_t i = 0; i < b.size(); ++i) {
if (b[i]->smooth_embedding.empty()) continue;
std::vector<int>::iterator iter = find(ids.begin(), ids.end(), b[i]->id);
if (iter == ids.end()) {
a.push_back(*b[i]);
ids.push_back(b[i]->id);
}
}
return a;
}
TrajectoryPool operator-(const TrajectoryPool &a, const TrajectoryPool &b) {
TrajectoryPool dif;
std::vector<int> ids(b.size());
for (size_t i = 0; i < b.size(); ++i) ids[i] = b[i].id;
for (size_t i = 0; i < a.size(); ++i) {
std::vector<int>::iterator iter = find(ids.begin(), ids.end(), a[i].id);
if (iter == ids.end()) dif.push_back(a[i]);
}
return dif;
}
TrajectoryPool &operator-=(TrajectoryPool &a, // NOLINT
const TrajectoryPool &b) {
std::vector<int> ids(b.size());
for (size_t i = 0; i < b.size(); ++i) ids[i] = b[i].id;
TrajectoryPoolIterator piter;
for (piter = a.begin(); piter != a.end();) {
std::vector<int>::iterator iter = find(ids.begin(), ids.end(), piter->id);
if (iter == ids.end())
++piter;
else
piter = a.erase(piter);
}
return a;
}
TrajectoryPtrPool operator+(const TrajectoryPtrPool &a,
const TrajectoryPtrPool &b) {
TrajectoryPtrPool sum;
sum.insert(sum.end(), a.begin(), a.end());
std::vector<int> ids(a.size());
for (size_t i = 0; i < a.size(); ++i) ids[i] = a[i]->id;
for (size_t i = 0; i < b.size(); ++i) {
std::vector<int>::iterator iter = find(ids.begin(), ids.end(), b[i]->id);
if (iter == ids.end()) {
sum.push_back(b[i]);
ids.push_back(b[i]->id);
}
}
return sum;
}
TrajectoryPtrPool operator+(const TrajectoryPtrPool &a, TrajectoryPool *b) {
TrajectoryPtrPool sum;
sum.insert(sum.end(), a.begin(), a.end());
std::vector<int> ids(a.size());
for (size_t i = 0; i < a.size(); ++i) ids[i] = a[i]->id;
for (size_t i = 0; i < b->size(); ++i) {
std::vector<int>::iterator iter = find(ids.begin(), ids.end(), (*b)[i].id);
if (iter == ids.end()) {
sum.push_back(&(*b)[i]);
ids.push_back((*b)[i].id);
}
}
return sum;
}
TrajectoryPtrPool operator-(const TrajectoryPtrPool &a,
const TrajectoryPtrPool &b) {
TrajectoryPtrPool dif;
std::vector<int> ids(b.size());
for (size_t i = 0; i < b.size(); ++i) ids[i] = b[i]->id;
for (size_t i = 0; i < a.size(); ++i) {
std::vector<int>::iterator iter = find(ids.begin(), ids.end(), a[i]->id);
if (iter == ids.end()) dif.push_back(a[i]);
}
return dif;
}
cv::Mat embedding_distance(const TrajectoryPool &a, const TrajectoryPool &b) {
cv::Mat dists(a.size(), b.size(), CV_32F);
for (size_t i = 0; i < a.size(); ++i) {
float *distsi = dists.ptr<float>(i);
for (size_t j = 0; j < b.size(); ++j) {
cv::Mat u = a[i].smooth_embedding;
cv::Mat v = b[j].smooth_embedding;
double uv = u.dot(v);
double uu = u.dot(u);
double vv = v.dot(v);
double dist = std::abs(1. - uv / std::sqrt(uu * vv));
// double dist = cv::norm(a[i].smooth_embedding, b[j].smooth_embedding,
// cv::NORM_L2);
distsi[j] = static_cast<float>(std::max(std::min(dist, 2.), 0.));
}
}
return dists;
}
cv::Mat embedding_distance(const TrajectoryPtrPool &a,
const TrajectoryPtrPool &b) {
cv::Mat dists(a.size(), b.size(), CV_32F);
for (size_t i = 0; i < a.size(); ++i) {
float *distsi = dists.ptr<float>(i);
for (size_t j = 0; j < b.size(); ++j) {
// double dist = cv::norm(a[i]->smooth_embedding, b[j]->smooth_embedding,
// cv::NORM_L2);
// distsi[j] = static_cast<float>(dist);
cv::Mat u = a[i]->smooth_embedding;
cv::Mat v = b[j]->smooth_embedding;
double uv = u.dot(v);
double uu = u.dot(u);
double vv = v.dot(v);
double dist = std::abs(1. - uv / std::sqrt(uu * vv));
distsi[j] = static_cast<float>(std::max(std::min(dist, 2.), 0.));
}
}
return dists;
}
cv::Mat embedding_distance(const TrajectoryPtrPool &a,
const TrajectoryPool &b) {
cv::Mat dists(a.size(), b.size(), CV_32F);
for (size_t i = 0; i < a.size(); ++i) {
float *distsi = dists.ptr<float>(i);
for (size_t j = 0; j < b.size(); ++j) {
// double dist = cv::norm(a[i]->smooth_embedding, b[j].smooth_embedding,
// cv::NORM_L2);
// distsi[j] = static_cast<float>(dist);
cv::Mat u = a[i]->smooth_embedding;
cv::Mat v = b[j].smooth_embedding;
double uv = u.dot(v);
double uu = u.dot(u);
double vv = v.dot(v);
double dist = std::abs(1. - uv / std::sqrt(uu * vv));
distsi[j] = static_cast<float>(std::max(std::min(dist, 2.), 0.));
}
}
return dists;
}
cv::Mat mahalanobis_distance(const TrajectoryPool &a, const TrajectoryPool &b) {
std::vector<cv::Mat> means(a.size());
std::vector<cv::Mat> icovariances(a.size());
for (size_t i = 0; i < a.size(); ++i) {
cv::Mat covariance;
a[i].project(&means[i], &covariance);
cv::invert(covariance, icovariances[i]);
}
cv::Mat dists(a.size(), b.size(), CV_32F);
for (size_t i = 0; i < a.size(); ++i) {
float *distsi = dists.ptr<float>(i);
for (size_t j = 0; j < b.size(); ++j) {
const cv::Mat x(b[j].xyah);
float dist =
static_cast<float>(cv::Mahalanobis(x, means[i], icovariances[i]));
distsi[j] = dist * dist;
}
}
return dists;
}
cv::Mat mahalanobis_distance(const TrajectoryPtrPool &a,
const TrajectoryPtrPool &b) {
std::vector<cv::Mat> means(a.size());
std::vector<cv::Mat> icovariances(a.size());
for (size_t i = 0; i < a.size(); ++i) {
cv::Mat covariance;
a[i]->project(&means[i], &covariance);
cv::invert(covariance, icovariances[i]);
}
cv::Mat dists(a.size(), b.size(), CV_32F);
for (size_t i = 0; i < a.size(); ++i) {
float *distsi = dists.ptr<float>(i);
for (size_t j = 0; j < b.size(); ++j) {
const cv::Mat x(b[j]->xyah);
float dist =
static_cast<float>(cv::Mahalanobis(x, means[i], icovariances[i]));
distsi[j] = dist * dist;
}
}
return dists;
}
cv::Mat mahalanobis_distance(const TrajectoryPtrPool &a,
const TrajectoryPool &b) {
std::vector<cv::Mat> means(a.size());
std::vector<cv::Mat> icovariances(a.size());
for (size_t i = 0; i < a.size(); ++i) {
cv::Mat covariance;
a[i]->project(&means[i], &covariance);
cv::invert(covariance, icovariances[i]);
}
cv::Mat dists(a.size(), b.size(), CV_32F);
for (size_t i = 0; i < a.size(); ++i) {
float *distsi = dists.ptr<float>(i);
for (size_t j = 0; j < b.size(); ++j) {
const cv::Mat x(b[j].xyah);
float dist =
static_cast<float>(cv::Mahalanobis(x, means[i], icovariances[i]));
distsi[j] = dist * dist;
}
}
return dists;
}
static inline float calc_inter_area(const cv::Vec4f &a, const cv::Vec4f &b) {
if (a[2] < b[0] || a[0] > b[2] || a[3] < b[1] || a[1] > b[3]) return 0.f;
float w = std::min(a[2], b[2]) - std::max(a[0], b[0]);
float h = std::min(a[3], b[3]) - std::max(a[1], b[1]);
return w * h;
}
cv::Mat iou_distance(const TrajectoryPool &a, const TrajectoryPool &b) {
std::vector<float> areaa(a.size());
for (size_t i = 0; i < a.size(); ++i) {
float w = a[i].ltrb[2] - a[i].ltrb[0];
float h = a[i].ltrb[3] - a[i].ltrb[1];
areaa[i] = w * h;
}
std::vector<float> areab(b.size());
for (size_t j = 0; j < b.size(); ++j) {
float w = b[j].ltrb[2] - b[j].ltrb[0];
float h = b[j].ltrb[3] - b[j].ltrb[1];
areab[j] = w * h;
}
cv::Mat dists(a.size(), b.size(), CV_32F);
for (size_t i = 0; i < a.size(); ++i) {
const cv::Vec4f &boxa = a[i].ltrb;
float *distsi = dists.ptr<float>(i);
for (size_t j = 0; j < b.size(); ++j) {
const cv::Vec4f &boxb = b[j].ltrb;
float inters = calc_inter_area(boxa, boxb);
distsi[j] = 1.f - inters / (areaa[i] + areab[j] - inters);
}
}
return dists;
}
cv::Mat iou_distance(const TrajectoryPtrPool &a, const TrajectoryPtrPool &b) {
std::vector<float> areaa(a.size());
for (size_t i = 0; i < a.size(); ++i) {
float w = a[i]->ltrb[2] - a[i]->ltrb[0];
float h = a[i]->ltrb[3] - a[i]->ltrb[1];
areaa[i] = w * h;
}
std::vector<float> areab(b.size());
for (size_t j = 0; j < b.size(); ++j) {
float w = b[j]->ltrb[2] - b[j]->ltrb[0];
float h = b[j]->ltrb[3] - b[j]->ltrb[1];
areab[j] = w * h;
}
cv::Mat dists(a.size(), b.size(), CV_32F);
for (size_t i = 0; i < a.size(); ++i) {
const cv::Vec4f &boxa = a[i]->ltrb;
float *distsi = dists.ptr<float>(i);
for (size_t j = 0; j < b.size(); ++j) {
const cv::Vec4f &boxb = b[j]->ltrb;
float inters = calc_inter_area(boxa, boxb);
distsi[j] = 1.f - inters / (areaa[i] + areab[j] - inters);
}
}
return dists;
}
cv::Mat iou_distance(const TrajectoryPtrPool &a, const TrajectoryPool &b) {
std::vector<float> areaa(a.size());
for (size_t i = 0; i < a.size(); ++i) {
float w = a[i]->ltrb[2] - a[i]->ltrb[0];
float h = a[i]->ltrb[3] - a[i]->ltrb[1];
areaa[i] = w * h;
}
std::vector<float> areab(b.size());
for (size_t j = 0; j < b.size(); ++j) {
float w = b[j].ltrb[2] - b[j].ltrb[0];
float h = b[j].ltrb[3] - b[j].ltrb[1];
areab[j] = w * h;
}
cv::Mat dists(a.size(), b.size(), CV_32F);
for (size_t i = 0; i < a.size(); ++i) {
const cv::Vec4f &boxa = a[i]->ltrb;
float *distsi = dists.ptr<float>(i);
for (size_t j = 0; j < b.size(); ++j) {
const cv::Vec4f &boxb = b[j].ltrb;
float inters = calc_inter_area(boxa, boxb);
distsi[j] = 1.f - inters / (areaa[i] + areab[j] - inters);
}
}
return dists;
}
} // namespace tracking
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,234 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The code is based on:
// https://github.com/CnybTseng/JDE/blob/master/platforms/common/trajectory.h
// Ths copyright of CnybTseng/JDE is as follows:
// MIT License
#pragma once
#include <vector>
#include "fastdeploy/fastdeploy_model.h"
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include "opencv2/video/tracking.hpp"
namespace fastdeploy {
namespace vision {
namespace tracking {
typedef enum { New = 0, Tracked = 1, Lost = 2, Removed = 3 } TrajectoryState;
class Trajectory;
typedef std::vector<Trajectory> TrajectoryPool;
typedef std::vector<Trajectory>::iterator TrajectoryPoolIterator;
typedef std::vector<Trajectory *> TrajectoryPtrPool;
typedef std::vector<Trajectory *>::iterator TrajectoryPtrPoolIterator;
class FASTDEPLOY_DECL TKalmanFilter : public cv::KalmanFilter {
public:
TKalmanFilter(void);
virtual ~TKalmanFilter(void) {}
virtual void init(const cv::Mat &measurement);
virtual const cv::Mat &predict();
virtual const cv::Mat &correct(const cv::Mat &measurement);
virtual void project(cv::Mat *mean, cv::Mat *covariance) const;
private:
float std_weight_position;
float std_weight_velocity;
};
inline TKalmanFilter::TKalmanFilter(void) : cv::KalmanFilter(8, 4) {
cv::KalmanFilter::transitionMatrix = cv::Mat::eye(8, 8, CV_32F);
for (int i = 0; i < 4; ++i)
cv::KalmanFilter::transitionMatrix.at<float>(i, i + 4) = 1;
cv::KalmanFilter::measurementMatrix = cv::Mat::eye(4, 8, CV_32F);
std_weight_position = 1 / 20.f;
std_weight_velocity = 1 / 160.f;
}
class FASTDEPLOY_DECL Trajectory : public TKalmanFilter {
public:
Trajectory();
Trajectory(const cv::Vec4f &ltrb, float score, const cv::Mat &embedding);
Trajectory(const Trajectory &other);
Trajectory &operator=(const Trajectory &rhs);
virtual ~Trajectory(void) {}
int next_id(int &nt);
virtual const cv::Mat &predict(void);
virtual void update(Trajectory *traj,
int timestamp,
bool update_embedding = true);
virtual void activate(int& cnt, int timestamp);
virtual void reactivate(Trajectory *traj, int & cnt,int timestamp, bool newid = false);
virtual void mark_lost(void);
virtual void mark_removed(void);
friend TrajectoryPool operator+(const TrajectoryPool &a,
const TrajectoryPool &b);
friend TrajectoryPool operator+(const TrajectoryPool &a,
const TrajectoryPtrPool &b);
friend TrajectoryPool &operator+=(TrajectoryPool &a, // NOLINT
const TrajectoryPtrPool &b);
friend TrajectoryPool operator-(const TrajectoryPool &a,
const TrajectoryPool &b);
friend TrajectoryPool &operator-=(TrajectoryPool &a, // NOLINT
const TrajectoryPool &b);
friend TrajectoryPtrPool operator+(const TrajectoryPtrPool &a,
const TrajectoryPtrPool &b);
friend TrajectoryPtrPool operator+(const TrajectoryPtrPool &a,
TrajectoryPool *b);
friend TrajectoryPtrPool operator-(const TrajectoryPtrPool &a,
const TrajectoryPtrPool &b);
friend cv::Mat embedding_distance(const TrajectoryPool &a,
const TrajectoryPool &b);
friend cv::Mat embedding_distance(const TrajectoryPtrPool &a,
const TrajectoryPtrPool &b);
friend cv::Mat embedding_distance(const TrajectoryPtrPool &a,
const TrajectoryPool &b);
friend cv::Mat mahalanobis_distance(const TrajectoryPool &a,
const TrajectoryPool &b);
friend cv::Mat mahalanobis_distance(const TrajectoryPtrPool &a,
const TrajectoryPtrPool &b);
friend cv::Mat mahalanobis_distance(const TrajectoryPtrPool &a,
const TrajectoryPool &b);
friend cv::Mat iou_distance(const TrajectoryPool &a, const TrajectoryPool &b);
friend cv::Mat iou_distance(const TrajectoryPtrPool &a,
const TrajectoryPtrPool &b);
friend cv::Mat iou_distance(const TrajectoryPtrPool &a,
const TrajectoryPool &b);
private:
void update_embedding(const cv::Mat &embedding);
public:
TrajectoryState state;
cv::Vec4f ltrb;
cv::Mat smooth_embedding;
int id;
bool is_activated;
int timestamp;
int starttime;
float score;
private:
// int count=0;
cv::Vec4f xyah;
cv::Mat current_embedding;
float eta;
int length;
};
inline cv::Vec4f ltrb2xyah(const cv::Vec4f &ltrb) {
cv::Vec4f xyah;
xyah[0] = (ltrb[0] + ltrb[2]) * 0.5f;
xyah[1] = (ltrb[1] + ltrb[3]) * 0.5f;
xyah[3] = ltrb[3] - ltrb[1];
xyah[2] = (ltrb[2] - ltrb[0]) / xyah[3];
return xyah;
}
inline Trajectory::Trajectory()
: state(New),
ltrb(cv::Vec4f()),
smooth_embedding(cv::Mat()),
id(0),
is_activated(false),
timestamp(0),
starttime(0),
score(0),
eta(0.9),
length(0) {}
inline Trajectory::Trajectory(const cv::Vec4f &ltrb_,
float score_,
const cv::Mat &embedding)
: state(New),
ltrb(ltrb_),
smooth_embedding(cv::Mat()),
id(0),
is_activated(false),
timestamp(0),
starttime(0),
score(score_),
eta(0.9),
length(0) {
xyah = ltrb2xyah(ltrb);
update_embedding(embedding);
}
inline Trajectory::Trajectory(const Trajectory &other)
: state(other.state),
ltrb(other.ltrb),
id(other.id),
is_activated(other.is_activated),
timestamp(other.timestamp),
starttime(other.starttime),
xyah(other.xyah),
score(other.score),
eta(other.eta),
length(other.length) {
other.smooth_embedding.copyTo(smooth_embedding);
other.current_embedding.copyTo(current_embedding);
// copy state in KalmanFilter
other.statePre.copyTo(cv::KalmanFilter::statePre);
other.statePost.copyTo(cv::KalmanFilter::statePost);
other.errorCovPre.copyTo(cv::KalmanFilter::errorCovPre);
other.errorCovPost.copyTo(cv::KalmanFilter::errorCovPost);
}
inline Trajectory &Trajectory::operator=(const Trajectory &rhs) {
this->state = rhs.state;
this->ltrb = rhs.ltrb;
rhs.smooth_embedding.copyTo(this->smooth_embedding);
this->id = rhs.id;
this->is_activated = rhs.is_activated;
this->timestamp = rhs.timestamp;
this->starttime = rhs.starttime;
this->xyah = rhs.xyah;
this->score = rhs.score;
rhs.current_embedding.copyTo(this->current_embedding);
this->eta = rhs.eta;
this->length = rhs.length;
// copy state in KalmanFilter
rhs.statePre.copyTo(cv::KalmanFilter::statePre);
rhs.statePost.copyTo(cv::KalmanFilter::statePost);
rhs.errorCovPre.copyTo(cv::KalmanFilter::errorCovPre);
rhs.errorCovPost.copyTo(cv::KalmanFilter::errorCovPost);
return *this;
}
inline int Trajectory::next_id(int &cnt) {
++cnt;
return cnt;
}
inline void Trajectory::mark_lost(void) { state = Lost; }
inline void Trajectory::mark_removed(void) { state = Removed; }
} // namespace tracking
} // namespace vision
} // namespace fastdeploy

View File

@@ -0,0 +1,26 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/pybind/main.h"
namespace fastdeploy {
void BindPPTracking(pybind11::module& m);
void BindTracking(pybind11::module& m) {
auto tracking_module =
m.def_submodule("tracking", "object tracking models.");
BindPPTracking(tracking_module);
}
} // namespace fastdeploy

View File

@@ -23,6 +23,7 @@ void BindMatting(pybind11::module& m);
void BindFaceDet(pybind11::module& m);
void BindFaceId(pybind11::module& m);
void BindOcr(pybind11::module& m);
void BindTracking(pybind11::module& m);
void BindKeyPointDetection(pybind11::module& m);
#ifdef ENABLE_VISION_VISUALIZE
void BindVisualize(pybind11::module& m);
@@ -63,6 +64,15 @@ void BindVision(pybind11::module& m) {
.def("__repr__", &vision::OCRResult::Str)
.def("__str__", &vision::OCRResult::Str);
pybind11::class_<vision::MOTResult>(m, "MOTResult")
.def(pybind11::init())
.def_readwrite("boxes", &vision::MOTResult::boxes)
.def_readwrite("ids", &vision::MOTResult::ids)
.def_readwrite("scores", &vision::MOTResult::scores)
.def_readwrite("class_ids", &vision::MOTResult::class_ids)
.def("__repr__", &vision::MOTResult::Str)
.def("__str__", &vision::MOTResult::Str);
pybind11::class_<vision::FaceDetectionResult>(m, "FaceDetectionResult")
.def(pybind11::init())
.def_readwrite("boxes", &vision::FaceDetectionResult::boxes)
@@ -112,6 +122,7 @@ void BindVision(pybind11::module& m) {
BindFaceId(m);
BindMatting(m);
BindOcr(m);
BindTracking(m);
BindKeyPointDetection(m);
#ifdef ENABLE_VISION_VISUALIZE
BindVisualize(m);

View File

@@ -0,0 +1,100 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef ENABLE_VISION_VISUALIZE
#include "fastdeploy/vision/visualize/visualize.h"
#include <iomanip>
namespace fastdeploy {
namespace vision {
cv::Scalar GetMOTBoxColor(int idx) {
idx = idx * 3;
cv::Scalar color = cv::Scalar((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255);
return color;
}
cv::Mat VisMOT(const cv::Mat &img, const MOTResult &results, float fps, int frame_id) {
cv::Mat vis_img = img.clone();
int im_h = img.rows;
int im_w = img.cols;
float text_scale = std::max(1, static_cast<int>(im_w / 1600.));
float text_thickness = 2.;
float line_thickness = std::max(1, static_cast<int>(im_w / 500.));
std::ostringstream oss;
oss << std::setiosflags(std::ios::fixed) << std::setprecision(4);
oss << "frame: " << frame_id << " ";
oss << "fps: " << fps << " ";
oss << "num: " << results.boxes.size();
std::string text = oss.str();
cv::Point origin;
origin.x = 0;
origin.y = static_cast<int>(15 * text_scale);
cv::putText(vis_img,
text,
origin,
cv::FONT_HERSHEY_PLAIN,
text_scale,
cv::Scalar(0, 0, 255),
text_thickness);
for (int i = 0; i < results.boxes.size(); ++i) {
const int obj_id = results.ids[i];
const float score = results.scores[i];
cv::Scalar color = GetMOTBoxColor(obj_id);
cv::Point pt1 = cv::Point(results.boxes[i][0], results.boxes[i][1]);
cv::Point pt2 = cv::Point(results.boxes[i][2], results.boxes[i][3]);
cv::Point id_pt =
cv::Point(results.boxes[i][0], results.boxes[i][1] + 10);
cv::Point score_pt =
cv::Point(results.boxes[i][0], results.boxes[i][1] - 10);
cv::rectangle(vis_img, pt1, pt2, color, line_thickness);
std::ostringstream idoss;
idoss << std::setiosflags(std::ios::fixed) << std::setprecision(4);
idoss << obj_id;
std::string id_text = idoss.str();
cv::putText(vis_img,
id_text,
id_pt,
cv::FONT_HERSHEY_PLAIN,
text_scale,
cv::Scalar(0, 255, 255),
text_thickness);
std::ostringstream soss;
soss << std::setiosflags(std::ios::fixed) << std::setprecision(2);
soss << score;
std::string score_text = soss.str();
cv::putText(vis_img,
score_text,
score_pt,
cv::FONT_HERSHEY_PLAIN,
text_scale,
cv::Scalar(0, 255, 255),
text_thickness);
}
return vis_img;
}
}// namespace vision
} //namespace fastdepoly
#endif

View File

@@ -15,7 +15,6 @@
#ifdef ENABLE_VISION_VISUALIZE
#include "fastdeploy/vision/visualize/visualize.h"
#include "opencv2/imgproc/imgproc.hpp"
namespace fastdeploy {
namespace vision {

View File

@@ -77,6 +77,9 @@ FASTDEPLOY_DECL cv::Mat VisMatting(const cv::Mat& im,
const MattingResult& result,
bool remove_small_connected_area = false);
FASTDEPLOY_DECL cv::Mat VisOcr(const cv::Mat& im, const OCRResult& ocr_result);
FASTDEPLOY_DECL cv::Mat VisMOT(const cv::Mat& img,const MOTResult& results, float fps=0.0, int frame_id=0);
FASTDEPLOY_DECL cv::Mat SwapBackground(
const cv::Mat& im, const cv::Mat& background, const MattingResult& result,
bool remove_small_connected_area = false);

View File

@@ -75,6 +75,14 @@ void BindVisualize(pybind11::module& m) {
vision::Mat(vis_im).ShareWithTensor(&out);
return TensorToPyArray(out);
})
.def("vis_mot",
[](pybind11::array& im_data, vision::MOTResult& result,float fps, int frame_id) {
auto im = PyArrayToCvMat(im_data);
auto vis_im = vision::VisMOT(im, result,fps,frame_id);
FDTensor out;
vision::Mat(vis_im).ShareWithTensor(&out);
return TensorToPyArray(out);
})
.def("vis_matting",
[](pybind11::array& im_data, vision::MattingResult& result,
bool remove_small_connected_area) {
@@ -166,6 +174,14 @@ void BindVisualize(pybind11::module& m) {
vision::Mat(vis_im).ShareWithTensor(&out);
return TensorToPyArray(out);
})
.def_static("vis_mot",
[](pybind11::array& im_data, vision::MOTResult& result,float fps, int frame_id) {
auto im = PyArrayToCvMat(im_data);
auto vis_im = vision::VisMOT(im, result,fps,frame_id);
FDTensor out;
vision::Mat(vis_im).ShareWithTensor(&out);
return TensorToPyArray(out);
})
.def_static("vis_matting_alpha",
[](pybind11::array& im_data, vision::MattingResult& result,
bool remove_small_connected_area) {

View File

@@ -1 +0,0 @@

View File

@@ -16,8 +16,8 @@ from __future__ import absolute_import
from . import detection
from . import classification
from . import segmentation
from . import tracking
from . import keypointdetection
from . import matting
from . import facedet
from . import faceid

View File

@@ -0,0 +1,16 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from .pptracking import PPTracking

View File

@@ -0,0 +1,37 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from .... import FastDeployModel, ModelFormat
from .... import c_lib_wrap as C
class PPTracking(FastDeployModel):
def __init__(self,
model_file,
params_file,
config_file,
runtime_option=None,
model_format=ModelFormat.PADDLE):
super(PPTracking, self).__init__(runtime_option)
assert model_format == ModelFormat.PADDLE, "PPTracking model only support model format of ModelFormat.Paddle now."
self._model = C.vision.tracking.PPTracking(
model_file, params_file, config_file, self._runtime_option,
model_format)
assert self.initialized, "PPTracking model initialize failed."
def predict(self, input_image):
assert input_image is not None, "The input image data is None."
return self._model.predict(input_image)

View File

@@ -98,3 +98,7 @@ def swap_background(im_data,
def vis_ppocr(im_data, det_result):
return C.vision.vis_ppocr(im_data, det_result)
def vis_mot(im_data, mot_result, fps, frame_id):
return C.vision.vis_mot(im_data, mot_result, fps, frame_id)

View File

@@ -46,7 +46,7 @@ def process_on_linux(current_dir):
if len(items) != 4:
os.remove(os.path.join(root, f))
continue
if items[0].strip() not in ["libopencv_highgui", "libopencv_videoio", "libopencv_imgcodecs", "libopencv_imgproc", "libopencv_core"]:
if items[0].strip() not in ["libopencv_highgui", "libopencv_video", "libopencv_videoio", "libopencv_imgcodecs", "libopencv_imgproc", "libopencv_core"]:
os.remove(os.path.join(root, f))
all_libs_paths = [third_libs_path] + user_specified_dirs

View File

@@ -0,0 +1,89 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fastdeploy as fd
import cv2
import os
import numpy as np
import pickle
def test_pptracking_cpu():
model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/pptracking.tgz"
input_url = "https://bj.bcebos.com/paddlehub/fastdeploy/person.mp4"
fd.download_and_decompress(model_url, ".")
fd.download(input_url, ".")
model_path = "pptracking/fairmot_hrnetv2_w18_dlafpn_30e_576x320"
# use default backend
runtime_option = fd.RuntimeOption()
model_file = os.path.join(model_path, "model.pdmodel")
params_file = os.path.join(model_path, "model.pdiparams")
config_file = os.path.join(model_path, "infer_cfg.yml")
model = fd.vision.tracking.PPTracking(model_file, params_file, config_file, runtime_option=runtime_option)
cap = cv2.VideoCapture("./person.mp4")
frame_id = 0
while True:
_, frame = cap.read()
if frame is None:
break
result = model.predict(frame)
# compare diff
expect = pickle.load(open("pptracking/frame" + str(frame_id) + ".pkl", "rb"))
diff_boxes = np.fabs(np.array(expect["boxes"]) - np.array(result.boxes))
diff_scores = np.fabs(np.array(expect["scores"]) - np.array(result.scores))
diff = max(diff_boxes.max(), diff_scores.max())
thres = 1e-05
assert diff < thres, "The label diff is %f, which is bigger than %f" % (diff, thres)
frame_id = frame_id + 1
cv2.waitKey(30)
if frame_id >= 10:
cap.release()
cv2.destroyAllWindows()
break
def test_pptracking_gpu():
model_url = "https://bj.bcebos.com/paddlehub/fastdeploy/pptracking.tgz"
input_url = "https://bj.bcebos.com/paddlehub/fastdeploy/person.mp4"
fd.download_and_decompress(model_url, ".")
fd.download(input_url, ".")
model_path = "pptracking/fairmot_hrnetv2_w18_dlafpn_30e_576x320"
runtime_option = fd.RuntimeOption()
runtime_option.use_gpu()
# Not supported trt backend, up to now
# runtime_option.use_trt_backend()
model_file = os.path.join(model_path, "model.pdmodel")
params_file = os.path.join(model_path, "model.pdiparams")
config_file = os.path.join(model_path, "infer_cfg.yml")
model = fd.vision.tracking.PPTracking(model_file, params_file, config_file, runtime_option=runtime_option)
cap = cv2.VideoCapture("./person.mp4")
frame_id = 0
while True:
_, frame = cap.read()
if frame is None:
break
result = model.predict(frame)
# compare diff
expect = pickle.load(open("pptracking/frame" + str(frame_id) + ".pkl", "rb"))
diff_boxes = np.fabs(np.array(expect["boxes"]) - np.array(result.boxes))
diff_scores = np.fabs(np.array(expect["scores"]) - np.array(result.scores))
diff = max(diff_boxes.max(), diff_scores.max())
thres = 1e-05
assert diff < thres, "The label diff is %f, which is bigger than %f" % (diff, thres)
frame_id = frame_id + 1
cv2.waitKey(30)
if frame_id >= 10:
cap.release()
cv2.destroyAllWindows()
break