diff --git a/docs/api/vision_results/mot_result.md b/docs/api/vision_results/mot_result.md
index 0dd7cda71..4ce8d6cfb 100644
--- a/docs/api/vision_results/mot_result.md
+++ b/docs/api/vision_results/mot_result.md
@@ -37,4 +37,3 @@ fastdeploy.vision.MOTResult
 - **ids**(list of int): member variable; the ids of all tracked objects in a single frame, with the same number of elements as `boxes`
 - **scores**(list of float): member variable; the confidence scores of all objects detected in a single frame
 - **class_ids**(list of int): member variable; the class ids of all objects detected in a single frame
-
diff --git a/examples/vision/README.md b/examples/vision/README.md
index ca56edd48..03cdf7f40 100644
--- a/examples/vision/README.md
+++ b/examples/vision/README.md
@@ -2,16 +2,18 @@
 
 This directory provides deployment examples for all kinds of vision models, covering the following task types.
 
-| Task Type | Description | Result Struct |
-|:--------------|:-------------------------------------|:----------------------------------------------------------------------------------|
-| Detection | Object detection: takes an image, locates the objects in it, and returns detection box coordinates with class labels and confidences | [DetectionResult](../../docs/api/vision_results/detection_result.md) |
-| Segmentation | Semantic segmentation: takes an image and returns the predicted class and confidence of every pixel | [SegmentationResult](../../docs/api/vision_results/segmentation_result.md) |
-| Classification | Image classification: takes an image and returns its classification result and confidence | [ClassifyResult](../../docs/api/vision_results/classification_result.md) |
-| FaceDetection | Face detection: takes an image, locates the faces in it, and returns detection box coordinates and face keypoints | [FaceDetectionResult](../../docs/api/vision_results/face_detection_result.md) |
-| KeypointDetection | Keypoint detection: takes an image and returns the coordinates and confidence of each human-pose keypoint | [KeyPointDetectionResult](../../docs/api/vision_results/keypointdetection_result.md) |
-| FaceRecognition | Face recognition: takes an image and returns a face-feature embedding usable for similarity comparison | [FaceRecognitionResult](../../docs/api/vision_results/face_recognition_result.md) |
-| Matting | Matting: takes an image and returns the alpha value of every foreground pixel | [MattingResult](../../docs/api/vision_results/matting_result.md) |
-| OCR | Text box detection, orientation classification, and text recognition: takes an image and returns text box coordinates, orientation classes, and the text inside each box | [OCRResult](../../docs/api/vision_results/ocr_result.md) |
+| Task Type         | Description                                                                                                                                                               | Result Struct                                                                         |
+|:------------------|:--------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:--------------------------------------------------------------------------------------|
+| Detection         | Object detection: takes an image, locates the objects in it, and returns detection box coordinates with class labels and confidences                                     | [DetectionResult](../../docs/api/vision_results/detection_result.md)                 |
+| Segmentation      | Semantic segmentation: takes an image and returns the predicted class and confidence of every pixel                                                                      | [SegmentationResult](../../docs/api/vision_results/segmentation_result.md)           |
+| Classification    | Image classification: takes an image and returns its classification result and confidence                                                                                | [ClassifyResult](../../docs/api/vision_results/classification_result.md)             |
+| FaceDetection     | Face detection: takes an image, locates the faces in it, and returns detection box coordinates and face keypoints                                                        | [FaceDetectionResult](../../docs/api/vision_results/face_detection_result.md)        |
+| KeypointDetection | Keypoint detection: takes an image and returns the coordinates and confidence of each human-pose keypoint                                                                | [KeyPointDetectionResult](../../docs/api/vision_results/keypointdetection_result.md) |
+| FaceRecognition   | Face recognition: takes an image and returns a face-feature embedding usable for similarity comparison                                                                   | [FaceRecognitionResult](../../docs/api/vision_results/face_recognition_result.md)    |
+| Matting           | Matting: takes an image and returns the alpha value of every foreground pixel                                                                                            | [MattingResult](../../docs/api/vision_results/matting_result.md)                     |
+| OCR               | Text box detection, orientation classification, and text recognition: takes an image and returns text box coordinates, orientation classes, and the text inside each box | [OCRResult](../../docs/api/vision_results/ocr_result.md)                             |
+| MOT               | Multi-object tracking: takes an image, locates the objects in it, and returns detection box coordinates, object ids, and class confidences                               | [MOTResult](../../docs/api/vision_results/mot_result.md)                             |
+
 
 ## FastDeploy API Design
 
 Vision models share a fairly uniform task paradigm. When designing the API (for both C++ and Python), FastDeploy therefore splits vision model deployment into four steps
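The member list in mot_result.md is easier to grasp with a concrete read-out. A minimal Python sketch, assuming a PPTracking `model` has already been built (as in the example program further down) and `frame` is a decoded BGR video frame:

```python
# Minimal sketch: read the MOTResult members documented above.
# Assumes `model` (fd.vision.tracking.PPTracking) and `frame` (numpy BGR image) exist.
result = model.predict(frame)
for box, obj_id, score in zip(result.boxes, result.ids, result.scores):
    x1, y1, x2, y2 = box  # one detection box: top-left and bottom-right corners
    print(f"id={obj_id} score={score:.2f} box=({x1}, {y1}, {x2}, {y2})")
```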
diff --git a/examples/vision/tracking/pptracking/cpp/infer.cc b/examples/vision/tracking/pptracking/cpp/infer.cc
index 709159eb4..58b4d4b61 100644
--- a/examples/vision/tracking/pptracking/cpp/infer.cc
+++ b/examples/vision/tracking/pptracking/cpp/infer.cc
@@ -33,25 +33,29 @@ void CpuInfer(const std::string& model_dir, const std::string& video_file) {
   }
 
   fastdeploy::vision::MOTResult result;
+  fastdeploy::vision::tracking::TrailRecorder recorder;
+  // Each call to Predict appends trail data to the recorder, so memory grows
+  // with the number of processed frames; call 'UnbindRecorder' to stop recording.
+  // int count = 0;  // counter for the unbind condition shown below
+  model.BindRecorder(&recorder);
   cv::Mat frame;
-  int frame_id=0;
   cv::VideoCapture capture(video_file);
-  // according to the time of prediction to calculate fps
-  float fps= 0.0f;
   while (capture.read(frame)) {
     if (frame.empty()) {
-    break;
+      break;
     }
     if (!model.Predict(&frame, &result)) {
-    std::cerr << "Failed to predict." << std::endl;
-    return;
+      std::cerr << "Failed to predict." << std::endl;
+      return;
     }
+    // For example, uncomment this line to stop recording trail data after 10 frames:
+    // if (count++ == 10) model.UnbindRecorder();
     // std::cout << result.Str() << std::endl;
-    cv::Mat out_img = fastdeploy::vision::VisMOT(frame, result, fps , frame_id);
+    cv::Mat out_img = fastdeploy::vision::VisMOT(frame, result, 0.0, &recorder);
     cv::imshow("mot",out_img);
     cv::waitKey(30);
-    frame_id++;
   }
+  model.UnbindRecorder();
   capture.release();
   cv::destroyAllWindows();
 }
@@ -72,25 +76,29 @@ void GpuInfer(const std::string& model_dir, const std::string& video_file) {
   }
 
   fastdeploy::vision::MOTResult result;
+  fastdeploy::vision::tracking::TrailRecorder trail_recorder;
+  // Each call to Predict appends trail data to the recorder, so memory grows
+  // with the number of processed frames; call 'UnbindRecorder' to stop recording.
+  // int count = 0;  // counter for the unbind condition shown below
+  model.BindRecorder(&trail_recorder);
   cv::Mat frame;
-  int frame_id=0;
   cv::VideoCapture capture(video_file);
-  // according to the time of prediction to calculate fps
-  float fps= 0.0f;
   while (capture.read(frame)) {
     if (frame.empty()) {
-    break;
+      break;
     }
     if (!model.Predict(&frame, &result)) {
-    std::cerr << "Failed to predict." << std::endl;
-    return;
+      std::cerr << "Failed to predict." << std::endl;
+      return;
     }
+    // For example, uncomment this line to stop recording trail data after 10 frames:
+    // if (count++ == 10) model.UnbindRecorder();
     // std::cout << result.Str() << std::endl;
-    cv::Mat out_img = fastdeploy::vision::VisMOT(frame, result, fps , frame_id);
+    cv::Mat out_img = fastdeploy::vision::VisMOT(frame, result, 0.0, &trail_recorder);
     cv::imshow("mot",out_img);
     cv::waitKey(30);
-    frame_id++;
   }
+  model.UnbindRecorder();
   capture.release();
   cv::destroyAllWindows();
 }
@@ -112,11 +120,13 @@ void TrtInfer(const std::string& model_dir, const std::string& video_file) {
   }
 
   fastdeploy::vision::MOTResult result;
+  fastdeploy::vision::tracking::TrailRecorder recorder;
+  // Each call to Predict appends trail data to the recorder, so memory grows
+  // with the number of processed frames; call 'UnbindRecorder' to stop recording.
+  // int count = 0;  // counter for the unbind condition shown below
+  model.BindRecorder(&recorder);
   cv::Mat frame;
-  int frame_id=0;
   cv::VideoCapture capture(video_file);
-  // according to the time of prediction to calculate fps
-  float fps= 0.0f;
   while (capture.read(frame)) {
     if (frame.empty()) {
       break;
     }
@@ -125,12 +135,14 @@ void TrtInfer(const std::string& model_dir, const std::string& video_file) {
       std::cerr << "Failed to predict." << std::endl;
       return;
     }
+    // For example, uncomment this line to stop recording trail data after 10 frames:
+    // if (count++ == 10) model.UnbindRecorder();
     // std::cout << result.Str() << std::endl;
-    cv::Mat out_img = fastdeploy::vision::VisMOT(frame, result, fps , frame_id);
+    cv::Mat out_img = fastdeploy::vision::VisMOT(frame, result, 0.0, &recorder);
     cv::imshow("mot",out_img);
     cv::waitKey(30);
-    frame_id++;
   }
+  model.UnbindRecorder();
   capture.release();
   cv::destroyAllWindows();
 }
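The comments added above warn that the recorder grows without bound while it stays bound. Here is the count-based unbind pattern from those comments as a Python sketch; the 100-frame cutoff and the video path are arbitrary assumptions:

```python
import cv2
import fastdeploy as fd

# Assumes `model` is an initialized fd.vision.tracking.PPTracking instance.
recorder = fd.vision.tracking.TrailRecorder()
model.bind_recorder(recorder)

cap = cv2.VideoCapture("input.mp4")  # hypothetical video file
count = 0
while True:
    ok, frame = cap.read()
    if not ok:
        break
    result = model.predict(frame)
    count += 1
    if count == 100:  # cap memory use: stop recording trails after 100 frames
        model.unbind_recorder()
cap.release()
```

Prediction keeps working after unbinding; only the trail bookkeeping stops.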
diff --git a/examples/vision/tracking/pptracking/python/infer.py b/examples/vision/tracking/pptracking/python/infer.py
index 39681e7e5..378d89bc1 100644
--- a/examples/vision/tracking/pptracking/python/infer.py
+++ b/examples/vision/tracking/pptracking/python/infer.py
@@ -14,7 +14,6 @@
 
 import fastdeploy as fd
 import cv2
-import time
 import os
 
@@ -60,20 +59,26 @@ config_file = os.path.join(args.model, "infer_cfg.yml")
 model = fd.vision.tracking.PPTracking(
     model_file, params_file, config_file, runtime_option=runtime_option)
 
+# Initialize the trail recorder
+recorder = fd.vision.tracking.TrailRecorder()
+# Bind the recorder. Note: every prediction inserts data into the recorder, so
+# memory keeps growing as more frames are predicted; call unbind_recorder() to
+# cancel the binding.
+model.bind_recorder(recorder)
 # Predict tracking results frame by frame
 cap = cv2.VideoCapture(args.video)
-frame_id = 0
+# count = 0
 while True:
-    start_time = time.time()
-    frame_id = frame_id+1
     _, frame = cap.read()
     if frame is None:
         break
     result = model.predict(frame)
-    end_time = time.time()
-    fps = 1.0/(end_time-start_time)
-    img = fd.vision.vis_mot(frame, result, fps, frame_id)
+    # count += 1
+    # if count == 10:
+    #     model.unbind_recorder()
+    img = fd.vision.vis_mot(frame, result, 0.0, recorder)
     cv2.imshow("video", img)
-    cv2.waitKey(30)
+    if cv2.waitKey(30) == ord("q"):
+        break
+model.unbind_recorder()
 cap.release()
 cv2.destroyAllWindows()
diff --git a/fastdeploy/vision/common/result.h b/fastdeploy/vision/common/result.h
index 9e613470f..1acca3140 100755
--- a/fastdeploy/vision/common/result.h
+++ b/fastdeploy/vision/common/result.h
@@ -14,6 +14,7 @@
 #pragma once
 #include "fastdeploy/fastdeploy_model.h"
 #include "opencv2/core/core.hpp"
+#include <array>
 
 namespace fastdeploy {
 /** \brief All C++ FastDeploy Vision Models APIs are defined inside this namespace
@@ -171,6 +172,7 @@ struct FASTDEPLOY_DECL MOTResult : public BaseResult {
   /** \brief The classification label id of each tracked object */
   std::vector<int> class_ids;
+  ResultType type = ResultType::MOT;
 
   /// Clear MOT result
   void Clear();
diff --git a/fastdeploy/vision/tracking/pptracking/model.cc b/fastdeploy/vision/tracking/pptracking/model.cc
index 97d4e1ab9..0ae550ad2 100644
--- a/fastdeploy/vision/tracking/pptracking/model.cc
+++ b/fastdeploy/vision/tracking/pptracking/model.cc
@@ -161,9 +161,7 @@ bool PPTracking::Initialize() {
     return false;
   }
   // create JDETracker instance
-  std::unique_ptr<JDETracker> jdeTracker(new JDETracker);
-  jdeTracker_ = std::move(jdeTracker);
-
+  jdeTracker_ = std::unique_ptr<JDETracker>(new JDETracker);
   return true;
 }
 
@@ -245,7 +243,6 @@ bool PPTracking::Postprocess(std::vector<FDTensor>& infer_result, MOTResult* result) {
   cv::Mat dets(bbox_shape[0], 6, CV_32FC1, bbox_data);
   cv::Mat emb(bbox_shape[0], emb_shape[1], CV_32FC1, emb_data);
-
   result->Clear();
   std::vector<Track> tracks;
   std::vector<int> valid;
@@ -264,7 +261,6 @@ bool PPTracking::Postprocess(std::vector<FDTensor>& infer_result, MOTResult* result) {
     result->boxes.push_back(box);
     result->ids.push_back(1);
     result->scores.push_back(*dets.ptr<float>(0, 4));
-
   } else {
     std::vector<Track>::iterator titer;
     for (titer = tracks.begin(); titer != tracks.end(); ++titer) {
@@ -285,9 +281,36 @@ bool PPTracking::Postprocess(std::vector<FDTensor>& infer_result, MOTResult* result) {
       }
     }
   }
+  if (!is_record_trail_) return true;
+  // Record the center point of every tracked box for trail visualization.
+  int nums = result->boxes.size();
+  for (int i = 0; i < nums; i++) {
+    float center_x = (result->boxes[i][0] + result->boxes[i][2]) / 2;
+    float center_y = (result->boxes[i][1] + result->boxes[i][3]) / 2;
+    int id = result->ids[i];
+    recorder_->Add(id, {int(center_x), int(center_y)});
+  }
   return true;
 }
 
+void PPTracking::BindRecorder(TrailRecorder* recorder) {
+  recorder_ = recorder;
+  is_record_trail_ = true;
+}
+
+void PPTracking::UnbindRecorder() {
+  is_record_trail_ = false;
+  if (recorder_ == nullptr) return;  // already unbound; make repeated calls safe
+  // Release the recorded trails and return the map's memory to the allocator.
+  std::map<int, std::vector<std::array<int, 2>>>::iterator iter;
+  for (iter = recorder_->records.begin(); iter != recorder_->records.end(); iter++) {
+    iter->second.clear();
+    iter->second.shrink_to_fit();
+  }
+  recorder_->records.clear();
+  std::map<int, std::vector<std::array<int, 2>>>().swap(recorder_->records);
+  recorder_ = nullptr;
+}
+
 }  // namespace tracking
 }  // namespace vision
 }  // namespace fastdeploy
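Postprocess above records one point per tracked box: the box center, truncated to ints. For example, the box [100, 50, 200, 150] yields the center (150, 100). The same bookkeeping mirrored in plain Python (illustrative only; `records` stands in for TrailRecorder.records):

```python
# Stand-in for TrailRecorder.records: object id -> list of (cx, cy) center points.
records = {}

def add_trail_point(obj_id, box):
    x1, y1, x2, y2 = box
    center = (int((x1 + x2) / 2), int((y1 + y2) / 2))
    records.setdefault(obj_id, []).append(center)

add_trail_point(7, [100, 50, 200, 150])
add_trail_point(7, [110, 60, 210, 160])
print(records)  # {7: [(150, 100), (160, 110)]}
```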
diff --git a/fastdeploy/vision/tracking/pptracking/model.h b/fastdeploy/vision/tracking/pptracking/model.h
index dc8f44f9d..3d78d05fb 100755
--- a/fastdeploy/vision/tracking/pptracking/model.h
+++ b/fastdeploy/vision/tracking/pptracking/model.h
@@ -14,6 +14,7 @@
 
 #pragma once
 
+#include <map>
 #include "fastdeploy/vision/common/processors/transform.h"
 #include "fastdeploy/fastdeploy_model.h"
 #include "fastdeploy/vision/common/result.h"
@@ -22,6 +23,21 @@
 namespace fastdeploy {
 namespace vision {
 namespace tracking {
 
+struct TrailRecorder {
+  // object id -> sequence of recorded center points
+  std::map<int, std::vector<std::array<int, 2>>> records;
+  void Add(int id, const std::array<int, 2>& record);
+};
+
+inline void TrailRecorder::Add(int id, const std::array<int, 2>& record) {
+  auto iter = records.find(id);
+  if (iter != records.end()) {
+    auto trail = records[id];
+    trail.push_back(record);
+    records[id] = trail;
+  } else {
+    records[id] = {record};
+  }
+}
+
 class FASTDEPLOY_DECL PPTracking : public FastDeployModel {
  public:
@@ -49,6 +65,14 @@ class FASTDEPLOY_DECL PPTracking : public FastDeployModel {
    * \return true if the prediction succeeded, otherwise false
    */
   virtual bool Predict(cv::Mat* img, MOTResult* result);
+  /** \brief Bind a trail recorder
+   *
+   * \param[in] recorder Recorder that stores each tracked object's trail (id and center-point sequence)
+   */
+  void BindRecorder(TrailRecorder* recorder);
+  /** \brief Cancel the binding and clear the recorded trail data
+   */
+  void UnbindRecorder();
 
  private:
   bool BuildPreprocessPipelineFromConfig();
@@ -65,8 +89,11 @@ class FASTDEPLOY_DECL PPTracking : public FastDeployModel {
   float conf_thresh_;
   float tracked_thresh_;
   float min_box_area_;
+  bool is_record_trail_ = false;
   std::unique_ptr<JDETracker> jdeTracker_;
+  TrailRecorder* recorder_ = nullptr;
 };
+
 }  // namespace tracking
 }  // namespace vision
 }  // namespace fastdeploy
diff --git a/fastdeploy/vision/tracking/pptracking/pptracking_pybind.cc b/fastdeploy/vision/tracking/pptracking/pptracking_pybind.cc
index d56437ad5..a5638628e 100644
--- a/fastdeploy/vision/tracking/pptracking/pptracking_pybind.cc
+++ b/fastdeploy/vision/tracking/pptracking/pptracking_pybind.cc
@@ -15,6 +15,11 @@ namespace fastdeploy {
 void BindPPTracking(pybind11::module& m) {
+
+  pybind11::class_<vision::tracking::TrailRecorder>(m, "TrailRecorder")
+      .def(pybind11::init<>())
+      .def_readwrite("records", &vision::tracking::TrailRecorder::records)
+      .def("add", &vision::tracking::TrailRecorder::Add);
   pybind11::class_<vision::tracking::PPTracking, FastDeployModel>(
       m, "PPTracking")
       .def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
                           ModelFormat>())
@@ ... @@ void BindPPTracking(pybind11::module& m) {
             return res;
-          });
+          })
+      .def("bind_recorder", &vision::tracking::PPTracking::BindRecorder)
+      .def("unbind_recorder", &vision::tracking::PPTracking::UnbindRecorder);
 }
 }  // namespace fastdeploy
diff --git a/fastdeploy/vision/tracking/pptracking/trajectory.cc b/fastdeploy/vision/tracking/pptracking/trajectory.cc
--- a/fastdeploy/vision/tracking/pptracking/trajectory.cc
+++ b/fastdeploy/vision/tracking/pptracking/trajectory.cc
@@ ... @@
   update_embedding(traj->current_embedding);
 }
 
-void Trajectory::activate(int &cnt,int timestamp_) {
+void Trajectory::activate(int &cnt, int timestamp_) {
   id = next_id(cnt);
   TKalmanFilter::init(cv::Mat(xyah));
   length = 0;
@@ -130,7 +130,7 @@ void Trajectory::activate(int &cnt, int timestamp_) {
   starttime = timestamp_;
 }
 
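The pybind block above exposes `records` and `add`, so a recorder can be inspected (or seeded) from Python. A sketch assuming the bindings build as shown and that pybind11's STL casters convert `records` into a dict of lists; the id and points are made up:

```python
import fastdeploy as fd

recorder = fd.vision.tracking.TrailRecorder()
recorder.add(1, [320, 240])  # same entry point Postprocess uses from C++
recorder.add(1, [330, 244])
for obj_id, trail in recorder.records.items():
    print(obj_id, trail)  # 1 [[320, 240], [330, 244]]
```

One design note: `TrailRecorder::Add` copies the existing trail vector before appending and then writes it back; `records[id].push_back(record)` would append in place without the copy.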
-void Trajectory::reactivate(Trajectory *traj,int &cnt, int timestamp_, bool newid) {
+void Trajectory::reactivate(Trajectory *traj, int &cnt, int timestamp_, bool newid) {
   TKalmanFilter::correct(cv::Mat(traj->xyah));
   update_embedding(traj->current_embedding);
   length = 0;
diff --git a/fastdeploy/vision/tracking/pptracking/trajectory.h b/fastdeploy/vision/tracking/pptracking/trajectory.h
index a869f8409..793419ce1 100644
--- a/fastdeploy/vision/tracking/pptracking/trajectory.h
+++ b/fastdeploy/vision/tracking/pptracking/trajectory.h
@@ -74,8 +74,8 @@ class FASTDEPLOY_DECL Trajectory : public TKalmanFilter {
   virtual void update(Trajectory *traj, int timestamp,
                       bool update_embedding = true);
-  virtual void activate(int& cnt, int timestamp);
-  virtual void reactivate(Trajectory *traj, int & cnt,int timestamp, bool newid = false);
+  virtual void activate(int &cnt, int timestamp);
+  virtual void reactivate(Trajectory *traj, int &cnt, int timestamp, bool newid = false);
 
   virtual void mark_lost(void);
   virtual void mark_removed(void);
diff --git a/fastdeploy/vision/visualize/mot.cc b/fastdeploy/vision/visualize/mot.cc
index 9877b2d4e..a04fda8e7 100644
--- a/fastdeploy/vision/visualize/mot.cc
+++ b/fastdeploy/vision/visualize/mot.cc
@@ -25,73 +25,63 @@ cv::Scalar GetMOTBoxColor(int idx) {
   return color;
 }
 
-cv::Mat VisMOT(const cv::Mat &img, const MOTResult &results, float fps, int frame_id) {
-
+cv::Mat VisMOT(const cv::Mat &img, const MOTResult &results,
+               float score_threshold, tracking::TrailRecorder* recorder) {
   cv::Mat vis_img = img.clone();
   int im_h = img.rows;
   int im_w = img.cols;
   float text_scale = std::max(1, static_cast<int>(im_w / 1600.));
   float text_thickness = 2.;
   float line_thickness = std::max(1, static_cast<int>(im_w / 500.));
-
-  std::ostringstream oss;
-  oss << std::setiosflags(std::ios::fixed) << std::setprecision(4);
-  oss << "frame: " << frame_id << " ";
-  oss << "fps: " << fps << " ";
-  oss << "num: " << results.boxes.size();
-  std::string text = oss.str();
-
-  cv::Point origin;
-  origin.x = 0;
-  origin.y = static_cast<int>(15 * text_scale);
-  cv::putText(vis_img,
-              text,
-              origin,
-              cv::FONT_HERSHEY_PLAIN,
-              text_scale,
-              cv::Scalar(0, 0, 255),
-              text_thickness);
-
   for (int i = 0; i < results.boxes.size(); ++i) {
-    const int obj_id = results.ids[i];
-    const float score = results.scores[i];
+    if (results.scores[i] < score_threshold) {
+      continue;
+    }
+    const int obj_id = results.ids[i];
+    const float score = results.scores[i];
+    cv::Scalar color = GetMOTBoxColor(obj_id);
+    if (recorder != nullptr) {
+      int id = results.ids[i];
+      auto iter = recorder->records.find(id);
+      if (iter != recorder->records.end()) {
+        for (int j = 0; j < iter->second.size(); j++) {
+          cv::Point center(iter->second[j][0], iter->second[j][1]);
+          cv::circle(vis_img, center, text_thickness, color);
+        }
+      }
+    }
+    cv::Point pt1 = cv::Point(results.boxes[i][0], results.boxes[i][1]);
+    cv::Point pt2 = cv::Point(results.boxes[i][2], results.boxes[i][3]);
+    cv::Point id_pt =
+        cv::Point(results.boxes[i][0], results.boxes[i][1] + 10);
+    cv::Point score_pt =
+        cv::Point(results.boxes[i][0], results.boxes[i][1] - 10);
+    cv::rectangle(vis_img, pt1, pt2, color, line_thickness);
 
+    std::ostringstream idoss;
+    idoss << std::setiosflags(std::ios::fixed) << std::setprecision(4);
+    idoss << obj_id;
+    std::string id_text = idoss.str();
 
-    cv::Scalar color = GetMOTBoxColor(obj_id);
+    cv::putText(vis_img,
+                id_text,
+                id_pt,
+                cv::FONT_HERSHEY_PLAIN,
+                text_scale,
+                color,
+                text_thickness);
 
-    cv::Point pt1 = cv::Point(results.boxes[i][0], results.boxes[i][1]);
-    cv::Point pt2 = cv::Point(results.boxes[i][2], results.boxes[i][3]);
-    cv::Point id_pt =
-        cv::Point(results.boxes[i][0], results.boxes[i][1] + 10);
-    cv::Point score_pt =
-        cv::Point(results.boxes[i][0], results.boxes[i][1] - 10);
-    cv::rectangle(vis_img, pt1, pt2, color, line_thickness);
+    std::ostringstream soss;
+    soss << std::setiosflags(std::ios::fixed) << std::setprecision(2);
+    soss << score;
+    std::string score_text = soss.str();
 
-    std::ostringstream idoss;
-    idoss << std::setiosflags(std::ios::fixed) << std::setprecision(4);
-    idoss << obj_id;
-    std::string id_text = idoss.str();
-
-    cv::putText(vis_img,
-                id_text,
-                id_pt,
-                cv::FONT_HERSHEY_PLAIN,
-                text_scale,
-                cv::Scalar(0, 255, 255),
-                text_thickness);
-
-    std::ostringstream soss;
-    soss << std::setiosflags(std::ios::fixed) << std::setprecision(2);
-    soss << score;
-    std::string score_text = soss.str();
-
-    cv::putText(vis_img,
-                score_text,
-                score_pt,
-                cv::FONT_HERSHEY_PLAIN,
-                text_scale,
-                cv::Scalar(0, 255, 255),
-                text_thickness);
+    cv::putText(vis_img,
+                score_text,
+                score_pt,
+                cv::FONT_HERSHEY_PLAIN,
+                text_scale,
+                color,
+                text_thickness);
   }
   return vis_img;
 }
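With the new VisMOT signature, callers pick a confidence floor and optionally pass the recorder to get trail dots drawn in each object's box color. A Python-side sketch, assuming `frame`, `result`, and `recorder` from the earlier examples; the 0.3 threshold is arbitrary:

```python
import cv2
import fastdeploy as fd

# Assumes `frame`, `result`, and `recorder` exist as in the earlier sketches.
vis = fd.vision.vis_mot(frame, result, 0.3, recorder)  # boxes, ids, scores, trails
cv2.imshow("mot", vis)
cv2.waitKey(30)
```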
diff --git a/fastdeploy/vision/visualize/visualize.h b/fastdeploy/vision/visualize/visualize.h
index 2141a2764..d874409d0 100644
--- a/fastdeploy/vision/visualize/visualize.h
+++ b/fastdeploy/vision/visualize/visualize.h
@@ -17,6 +17,8 @@
 #include "fastdeploy/vision/common/result.h"
 #include "opencv2/imgproc/imgproc.hpp"
 
+#include "fastdeploy/vision/tracking/pptracking/model.h"
+
 namespace fastdeploy {
 namespace vision {
 
@@ -81,8 +83,9 @@
 FASTDEPLOY_DECL cv::Mat VisMatting(const cv::Mat& im,
                                    const MattingResult& result,
                                    bool remove_small_connected_area = false);
 FASTDEPLOY_DECL cv::Mat VisOcr(const cv::Mat& im, const OCRResult& ocr_result);
-FASTDEPLOY_DECL cv::Mat VisMOT(const cv::Mat& img,const MOTResult& results, float fps=0.0, int frame_id=0);
-
+FASTDEPLOY_DECL cv::Mat VisMOT(const cv::Mat& img, const MOTResult& results,
+                               float score_threshold = 0.0f,
+                               tracking::TrailRecorder* recorder = nullptr);
 FASTDEPLOY_DECL cv::Mat SwapBackground(
     const cv::Mat& im, const cv::Mat& background, const MattingResult& result,
     bool remove_small_connected_area = false);
diff --git a/fastdeploy/vision/visualize/visualize_pybind.cc b/fastdeploy/vision/visualize/visualize_pybind.cc
index 8cf8b7165..7633579cc 100644
--- a/fastdeploy/vision/visualize/visualize_pybind.cc
+++ b/fastdeploy/vision/visualize/visualize_pybind.cc
@@ -86,9 +86,9 @@ void BindVisualize(pybind11::module& m) {
            return TensorToPyArray(out);
          })
       .def("vis_mot",
-          [](pybind11::array& im_data, vision::MOTResult& result, float fps, int frame_id) {
+          [](pybind11::array& im_data, vision::MOTResult& result,
+             float score_threshold, vision::tracking::TrailRecorder* record) {
             auto im = PyArrayToCvMat(im_data);
-            auto vis_im = vision::VisMOT(im, result,fps,frame_id);
+            auto vis_im = vision::VisMOT(im, result, score_threshold, record);
             FDTensor out;
             vision::Mat(vis_im).ShareWithTensor(&out);
             return TensorToPyArray(out);
           })
@@ -185,9 +185,10 @@ void BindVisualize(pybind11::module& m) {
            return TensorToPyArray(out);
          })
       .def_static("vis_mot",
-          [](pybind11::array& im_data, vision::MOTResult& result, float fps, int frame_id) {
+          [](pybind11::array& im_data, vision::MOTResult& result,
+             float score_threshold, vision::tracking::TrailRecorder* record) {
             auto im = PyArrayToCvMat(im_data);
-            auto vis_im = vision::VisMOT(im, result,fps,frame_id);
+            auto vis_im = vision::VisMOT(im, result, score_threshold, record);
             FDTensor out;
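Both bindings take the recorder as a pointer: that avoids copying the whole trail map on every call, and a nullptr (Python None) means "draw no trails". A sketch of the two call styles, assuming the same `frame`, `result`, and `recorder` as before:

```python
# Assumes `fd`, `frame`, `result`, and `recorder` from the earlier sketches.
vis_plain = fd.vision.vis_mot(frame, result, 0.0, None)      # boxes only
vis_trail = fd.vision.vis_mot(frame, result, 0.0, recorder)  # boxes plus trails
```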
            vision::Mat(vis_im).ShareWithTensor(&out);
             return TensorToPyArray(out);
diff --git a/python/fastdeploy/vision/tracking/__init__.py b/python/fastdeploy/vision/tracking/__init__.py
index 946dfd971..d21c975e9 100644
--- a/python/fastdeploy/vision/tracking/__init__.py
+++ b/python/fastdeploy/vision/tracking/__init__.py
@@ -12,5 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from __future__ import absolute_import
-
+from ... import c_lib_wrap as C
 from .pptracking import PPTracking
+
+try:
+    TrailRecorder = C.vision.tracking.TrailRecorder
+except AttributeError:
+    pass
diff --git a/python/fastdeploy/vision/tracking/pptracking/__init__.py b/python/fastdeploy/vision/tracking/pptracking/__init__.py
index 89ca2a7b0..d26b4ba1f 100644
--- a/python/fastdeploy/vision/tracking/pptracking/__init__.py
+++ b/python/fastdeploy/vision/tracking/pptracking/__init__.py
@@ -48,3 +48,18 @@ class PPTracking(FastDeployModel):
         """
         assert input_image is not None, "The input image data is None."
         return self._model.predict(input_image)
+
+    def bind_recorder(self, val):
+        """Bind a trail recorder to the model.
+
+        :param val: (TrailRecorder) trail recorder that stores each object's id and center-point sequence
+        :return: None
+        """
+        self._model.bind_recorder(val)
+
+    def unbind_recorder(self):
+        """Unbind the trail recorder and clear the recorded data.
+
+        :return: None
+        """
+        self._model.unbind_recorder()
diff --git a/python/fastdeploy/vision/visualize/__init__.py b/python/fastdeploy/vision/visualize/__init__.py
index b7f7c7b14..ddbd8758e 100755
--- a/python/fastdeploy/vision/visualize/__init__.py
+++ b/python/fastdeploy/vision/visualize/__init__.py
@@ -106,5 +106,5 @@ def vis_ppocr(im_data, det_result):
     return C.vision.vis_ppocr(im_data, det_result)
 
 
-def vis_mot(im_data, mot_result, fps, frame_id):
-    return C.vision.vis_mot(im_data, mot_result, fps, frame_id)
+def vis_mot(im_data, mot_result, score_threshold=0.0, records=None):
+    return C.vision.vis_mot(im_data, mot_result, score_threshold, records)
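The wrapper above documents bind_recorder/unbind_recorder but not the lifetime caveat: UnbindRecorder in model.cc clears `records`, so snapshot any trail data you still need before unbinding. A sketch under that assumption:

```python
# Assumes `model` is bound to `recorder` and some frames were already predicted.
trails = {k: list(v) for k, v in recorder.records.items()}  # snapshot first
model.unbind_recorder()  # stops recording and clears recorder.records
print(len(trails), "object trails preserved")
```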