[Other] Add tests for TIMVX (#1605)

* add tests for timvx

* add mobilenetv1 test

* update code

* fix log info

* update log

* fix test

---------

Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com>
This commit is contained in:
yeliang2258
2023-04-14 10:31:36 +08:00
committed by GitHub
parent b30f62af36
commit 81fbd54c9d
13 changed files with 762 additions and 1 deletions

View File

@@ -35,7 +35,7 @@ if (DEFINED TARGET_ABI)
if(WITH_TIMVX) if(WITH_TIMVX)
set(PADDLELITE_URL "https://bj.bcebos.com/fastdeploy/third_libs/lite-linux-aarch64-timvx-20230316.tgz") set(PADDLELITE_URL "https://bj.bcebos.com/fastdeploy/third_libs/lite-linux-aarch64-timvx-20230316.tgz")
else() else()
set(PADDLELITE_URL "https://bj.bcebos.com/fastdeploy/third_libs/lite-linux-arm64-20221209.tgz") set(PADDLELITE_URL "https://bj.bcebos.com/fastdeploy/third_libs/lite-linux-arm64-20230316.tgz")
endif() endif()
set(THIRD_PARTY_PATH ${CMAKE_CURRENT_BINARY_DIR}/third_libs) set(THIRD_PARTY_PATH ${CMAKE_CURRENT_BINARY_DIR}/third_libs)
set(OpenCV_DIR ${THIRD_PARTY_PATH}/install/opencv/lib/cmake/opencv4) set(OpenCV_DIR ${THIRD_PARTY_PATH}/install/opencv/lib/cmake/opencv4)

3
tests/CMakeLists.txt Normal file → Executable file
View File

@@ -72,6 +72,9 @@ if(WITH_TESTING)
message(STATUS "") message(STATUS "")
message(STATUS "*************FastDeploy Unittest Summary**********") message(STATUS "*************FastDeploy Unittest Summary**********")
file(GLOB_RECURSE ALL_TEST_SRCS ${PROJECT_SOURCE_DIR}/tests/*/test_*.cc) file(GLOB_RECURSE ALL_TEST_SRCS ${PROJECT_SOURCE_DIR}/tests/*/test_*.cc)
file(GLOB_RECURSE TIMVX_SRCS ${PROJECT_SOURCE_DIR}/tests/timvx/test_*.cc)
list(REMOVE_ITEM ALL_TEST_SRCS ${TIMVX_SRCS})
if(NOT ENABLE_VISION) if(NOT ENABLE_VISION)
# vision_preprocess and release_task need vision # vision_preprocess and release_task need vision
file(GLOB_RECURSE VISION_TEST_SRCS ${PROJECT_SOURCE_DIR}/tests/vision_preprocess/test_*.cc) file(GLOB_RECURSE VISION_TEST_SRCS ${PROJECT_SOURCE_DIR}/tests/vision_preprocess/test_*.cc)

46
tests/timvx/CMakeLists.txt Executable file
View File

@@ -0,0 +1,46 @@
PROJECT(infer_demo C CXX)
CMAKE_MINIMUM_REQUIRED (VERSION 3.10)
# 指定下载解压后的fastdeploy库路径
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
# 添加FastDeploy依赖头文件
include_directories(${FASTDEPLOY_INCS})
include_directories(${FastDeploy_INCLUDE_DIRS})
set(CMAKE_INSTALL_PREFIX ${CMAKE_SOURCE_DIR}/build/timvx_tests)
# add test for yolov5
add_executable(test_yolov5 ${PROJECT_SOURCE_DIR}/test_yolov5.cc)
target_link_libraries(test_yolov5 ${FASTDEPLOY_LIBS})
install(TARGETS test_yolov5 DESTINATION ./)
# add test for ppyoloe
add_executable(test_ppyoloe ${PROJECT_SOURCE_DIR}/test_ppyoloe.cc)
target_link_libraries(test_ppyoloe ${FASTDEPLOY_LIBS})
install(TARGETS test_ppyoloe DESTINATION ./)
# add test for paddleclas
add_executable(test_clas ${PROJECT_SOURCE_DIR}/test_clas.cc)
target_link_libraries(test_clas ${FASTDEPLOY_LIBS})
install(TARGETS test_clas DESTINATION ./)
# add test for pp-liteseg
add_executable(test_ppliteseg ${PROJECT_SOURCE_DIR}/test_ppliteseg.cc)
target_link_libraries(test_ppliteseg ${FASTDEPLOY_LIBS})
install(TARGETS test_ppliteseg DESTINATION ./)
install(DIRECTORY models DESTINATION ./)
install(DIRECTORY images DESTINATION ./)
install(DIRECTORY results DESTINATION ./)
file(GLOB RUN_TEST run_test.sh)
install(PROGRAMS ${RUN_TEST} DESTINATION ./)
file(GLOB_RECURSE FASTDEPLOY_LIBS ${FASTDEPLOY_INSTALL_DIR}/lib/lib*.so*)
file(GLOB_RECURSE ALL_LIBS ${FASTDEPLOY_INSTALL_DIR}/third_libs/install/lib*.so*)
list(APPEND ALL_LIBS ${FASTDEPLOY_LIBS})
install(PROGRAMS ${ALL_LIBS} DESTINATION lib)

242
tests/timvx/common.h Executable file
View File

@@ -0,0 +1,242 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <fstream>
#include "fastdeploy/vision.h"
std::vector<std::string> stringSplit(const std::string& str, char delim) {
std::stringstream ss(str);
std::string item;
std::vector<std::string> elems;
while (std::getline(ss, item, delim)) {
if (!item.empty()) {
elems.push_back(item);
}
}
return elems;
}
bool CompareDetResult(const fastdeploy::vision::DetectionResult& res,
const std::string& det_result_file) {
std::ifstream res_str(det_result_file);
if (!res_str.is_open()) {
std::cout<< "Could not open detect result file : "
<< det_result_file <<"\n"<< std::endl;
return false;
}
int obj_num = 0;
while (!res_str.eof()) {
std::string line;
std::getline(res_str, line);
if (line.find("DetectionResult") == line.npos
&& line.find(",") != line.npos ) {
auto strs = stringSplit(line, ',');
if (strs.size() != 6) {
std::cout<< "Failed to parse result file : "
<< det_result_file <<"\n"<< std::endl;
return false;
}
std::vector<float> vals;
for (auto str : strs) {
vals.push_back(atof(str.c_str()));
}
if (abs(res.scores[obj_num] - vals[4]) > 0.3) {
std::cout<< "Score error, the result is: "
<< res.scores[obj_num] << " but the expected is: "
<< vals[4] << std::endl;
return false;
}
if (abs(res.label_ids[obj_num] - vals[5]) > 0) {
std::cout<< "label error, the result is: "
<< res.label_ids[obj_num] << " but the expected is: "
<< vals[5] <<std::endl;
return false;
}
std::array<float, 4> boxes = res.boxes[obj_num++];
for (auto i = 0; i < 4; i++) {
if (abs(boxes[i] - vals[i]) > 5) {
std::cout<< "position error, the result is: "
<< boxes[i] << " but the expected is: " << vals[i] <<std::endl;
return false;
}
}
}
}
return true;
}
bool CompareClsResult(const fastdeploy::vision::ClassifyResult& res,
const std::string& cls_result_file) {
std::ifstream res_str(cls_result_file);
if (!res_str.is_open()) {
std::cout<< "Could not open detect result file : "
<< cls_result_file << "\n" << std::endl;
return false;
}
int obj_num = 0;
while (!res_str.eof()) {
std::string line;
std::getline(res_str, line);
if (line.find("label_ids") != line.npos
&& line.find(":") != line.npos) {
auto strs = stringSplit(line, ':');
if (strs.size() != 2) {
std::cout<< "Failed to parse result file : "
<< cls_result_file <<"\n"<< std::endl;
return false;
}
int32_t label = static_cast<int32_t>(atof(strs[1].c_str()));
if (res.label_ids[obj_num] != label) {
std::cout<< "label error, the result is: "
<< res.label_ids[obj_num] << " but the expected is: "
<< label<< "\n" << std::endl;
return false;
}
} else if (line.find("scores") != line.npos
&& line.find(":") != line.npos) {
auto strs = stringSplit(line, ':');
if (strs.size() != 2) {
std::cout<< "Failed to parse result file : "
<< cls_result_file << "\n" << std::endl;
return false;
}
float score = atof(strs[1].c_str());
if (abs(res.scores[obj_num] - score) > 1e-1) {
std::cout << "score error, the result is: "
<< res.scores[obj_num] << " but the expected is: "
<< score << "\n" << std::endl;
return false;
} else {
obj_num++;
}
} else if (line.size()) {
std::cout << "Unknown File. \n" << std::endl;
return false;
}
}
return true;
}
bool WriteSegResult(const fastdeploy::vision::SegmentationResult& res,
const std::string& seg_result_file) {
std::ofstream res_str(seg_result_file);
if (!res_str.is_open()) {
std::cerr<< "Could not open segmentation result file : "
<< seg_result_file <<" to write.\n"<< std::endl;
return false;
}
std::string out;
out = "";
// save shape
for (auto shape : res.shape) {
out += std::to_string(shape) + ",";
}
out += "\n";
// save label
for (auto label : res.label_map) {
out += std::to_string(label) + ",";
}
out += "\n";
// save score
if (res.contain_score_map) {
for (auto score : res.score_map) {
out += std::to_string(score) + ",";
}
}
res_str << out;
return true;
}
bool CompareSegResult(const fastdeploy::vision::SegmentationResult& res,
const std::string& seg_result_file) {
std::ifstream res_str(seg_result_file);
if (!res_str.is_open()) {
std::cout<< "Could not open detect result file : "
<< seg_result_file <<"\n"<< std::endl;
return false;
}
std::string line;
std::getline(res_str, line);
if (line.find(",") == line.npos) {
std::cout << "Unexpected File." << std::endl;
return false;
}
// check shape diff
auto shape_strs = stringSplit(line, ',');
std::vector<int64_t> shape;
for (auto str : shape_strs) {
shape.push_back(static_cast<int64_t>(atof(str.c_str())));
}
if (shape.size() != res.shape.size()) {
std::cout << "Output shape and expected shape size mismatch, shape size: "
<< res.shape.size() << " expected shape size: "
<< shape.size() << std::endl;
return false;
}
for (auto i = 0; i < res.shape.size(); i++) {
if (res.shape[i] != shape[i]) {
std::cout << "Output Shape and expected shape mismatch, shape: "
<< res.shape[i] << " expected: " << shape[i] << std::endl;
return false;
}
}
std::cout << "Shape check passed!" << std::endl;
std::getline(res_str, line);
if (line.find(",") == line.npos) {
std::cout << "Unexpected File." << std::endl;
return false;
}
// check label
auto label_strs = stringSplit(line, ',');
std::vector<uint8_t> labels;
for (auto str : label_strs) {
labels.push_back(static_cast<uint8_t>(atof(str.c_str())));
}
if (labels.size() != res.label_map.size()) {
std::cout << "Output labels and expected shape size mismatch." << std::endl;
return false;
}
for (auto i = 0; i < res.label_map.size(); i++) {
if (res.label_map[i] != labels[i]) {
std::cout << "Output labels and expected labels mismatch." << std::endl;
return false;
}
}
std::cout << "Label check passed!" << std::endl;
// check score_map
if (res.contain_score_map) {
auto scores_strs = stringSplit(line, ',');
std::vector<float> scores;
for (auto str : scores_strs) {
scores.push_back(static_cast<float>(atof(str.c_str())));
}
if (scores.size() != res.score_map.size()) {
std::cout << "Output scores and expected score_map size mismatch."
<<std::endl;
return false;
}
for (auto i = 0; i < res.score_map.size(); i++) {
if (abs(res.score_map[i] - scores[i]) > 3e-1) {
std::cout << "Output scores and expected scores mismatch."
<< std::endl;
return false;
}
}
}
return true;
}

165
tests/timvx/download_models.py Executable file
View File

@@ -0,0 +1,165 @@
import os
import os.path as osp
import logging
import requests
import shutil
import zipfile
import tarfile
import hashlib
import tqdm
DOWNLOAD_RETRY_LIMIT = 3
def md5check(fullname, md5sum=None):
if md5sum is None:
return True
logging.info("File {} md5 checking...".format(fullname))
md5 = hashlib.md5()
with open(fullname, 'rb') as f:
for chunk in iter(lambda: f.read(4096), b""):
md5.update(chunk)
calc_md5sum = md5.hexdigest()
if calc_md5sum != md5sum:
logging.info("File {} md5 check failed, {}(calc) != "
"{}(base)".format(fullname, calc_md5sum, md5sum))
return False
return True
def move_and_merge_tree(src, dst):
"""
Move src directory to dst, if dst is already exists,
merge src to dst
"""
if not osp.exists(dst):
shutil.move(src, dst)
else:
if not osp.isdir(src):
shutil.move(src, dst)
return
for fp in os.listdir(src):
src_fp = osp.join(src, fp)
dst_fp = osp.join(dst, fp)
if osp.isdir(src_fp):
if osp.isdir(dst_fp):
move_and_merge_tree(src_fp, dst_fp)
else:
shutil.move(src_fp, dst_fp)
elif osp.isfile(src_fp) and \
not osp.isfile(dst_fp):
shutil.move(src_fp, dst_fp)
def download(url, path, rename=None, md5sum=None, show_progress=False):
"""
Download from url, save to path.
url (str): download url
path (str): download to given path
"""
if not osp.exists(path):
os.makedirs(path)
fname = osp.split(url)[-1]
fullname = osp.join(path, fname)
if rename is not None:
fullname = osp.join(path, rename)
retry_cnt = 0
while not (osp.exists(fullname) and md5check(fullname, md5sum)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
logging.debug("{} download failed.".format(fname))
raise RuntimeError("Download from {} failed. "
"Retry limit reached".format(url))
logging.info("Downloading {} from {}".format(fname, url))
req = requests.get(url, stream=True)
if req.status_code != 200:
raise RuntimeError("Downloading from {} failed with code "
"{}!".format(url, req.status_code))
# For protecting download interupted, download to
# tmp_fullname firstly, move tmp_fullname to fullname
# after download finished
tmp_fullname = fullname + "_tmp"
total_size = req.headers.get('content-length')
with open(tmp_fullname, 'wb') as f:
if total_size and show_progress:
for chunk in tqdm.tqdm(
req.iter_content(chunk_size=1024),
total=(int(total_size) + 1023) // 1024,
unit='KB'):
f.write(chunk)
else:
for chunk in req.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
logging.debug("{} download completed.".format(fname))
return fullname
def decompress(fname):
"""
Decompress for zip and tar file
"""
logging.info("Decompressing {}...".format(fname))
# For protecting decompressing interupted,
# decompress to fpath_tmp directory firstly, if decompress
# successed, move decompress files to fpath and delete
# fpath_tmp and remove download compress file.
fpath = osp.split(fname)[0]
fpath_tmp = osp.join(fpath, 'tmp')
if osp.isdir(fpath_tmp):
shutil.rmtree(fpath_tmp)
os.makedirs(fpath_tmp)
if fname.find('.tar') >= 0 or fname.find('.tgz') >= 0:
with tarfile.open(fname) as tf:
tf.extractall(path=fpath_tmp)
elif fname.find('.zip') >= 0:
with zipfile.ZipFile(fname) as zf:
zf.extractall(path=fpath_tmp)
else:
raise TypeError("Unsupport compress file type {}".format(fname))
for f in os.listdir(fpath_tmp):
src_dir = osp.join(fpath_tmp, f)
dst_dir = osp.join(fpath, f)
move_and_merge_tree(src_dir, dst_dir)
shutil.rmtree(fpath_tmp)
logging.debug("{} decompressed.".format(fname))
return dst_dir
def download_and_decompress(url, path='.', rename=None):
full_name = download(url, path, rename)
if url.count(".tgz") > 0 or url.count(".tar") > 0 or url.count("zip") > 0:
return decompress(full_name)
return
def unset_env(key):
del os.environ[key]
if __name__ == '__main__':
with open("models_url.txt", "r") as f:
if 'https_proxy' in os.environ or 'http_proxy' in os.environ:
unset_env("https_proxy")
unset_env("http_proxy")
for line in f.readlines():
url = line.strip()
print("Downloading: ", url)
if line.count(".tgz") > 0 or line.count(".tar") > 0 or line.count(
"zip") > 0:
dst_dir = download_and_decompress(url, "./models")
else:
dst_dir = download(url, "./images", None)

8
tests/timvx/models_url.txt Executable file
View File

@@ -0,0 +1,8 @@
https://bj.bcebos.com/paddlehub/fastdeploy/mobilenetv1_ssld_ptq.tar
https://bj.bcebos.com/paddlehub/fastdeploy/resnet50_vd_ptq.tar
https://bj.bcebos.com/fastdeploy/models/yolov5s_ptq_model.tar.gz
https://bj.bcebos.com/fastdeploy/models/ppyoloe_noshare_qat.tar.gz
https://bj.bcebos.com/fastdeploy/models/rk1/ppliteseg.tar.gz
https://gitee.com/paddlepaddle/PaddleClas/raw/release/2.4/deploy/images/ImageNet/ILSVRC2012_val_00000010.jpeg
https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
https://paddleseg.bj.bcebos.com/dygraph/demo/cityscapes_demo.png

4
tests/timvx/prepare.sh Executable file
View File

@@ -0,0 +1,4 @@
python download_models.py
wget https://bj.bcebos.com/fastdeploy/models/results.tar.gz
tar -xf results.tar.gz && rm -rf results.tar.gz

1
tests/timvx/requirements.txt Executable file
View File

@@ -0,0 +1 @@
tqdm

9
tests/timvx/run_test.sh Executable file
View File

@@ -0,0 +1,9 @@
export LD_LIBRARY_PATH=${PWD}/lib
export VIV_VX_ENABLE_GRAPH_TRANSFORM=-pcq:1
export VIV_VX_SET_PER_CHANNEL_ENTROPY=100
./test_clas models/mobilenetv1_ssld_ptq images/ILSVRC2012_val_00000010.jpeg results/mobilenetv1_cls.txt
./test_clas models/resnet50_vd_ptq/ images/ILSVRC2012_val_00000010.jpeg results/resnet50_cls.txt
./test_yolov5 models/yolov5s_ptq_model/ images/000000014439.jpg results/yolov5_result.txt
./test_ppyoloe models/ppyoloe_noshare_qat/ images/000000014439.jpg results/ppyoloe_result.txt
./test_ppliteseg models/ppliteseg images/cityscapes_demo.png results/ppliteseg_result.txt

68
tests/timvx/test_clas.cc Executable file
View File

@@ -0,0 +1,68 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "common.h"
#include "fastdeploy/vision.h"
#ifdef WIN32
const char sep = '\\';
#else
const char sep = '/';
#endif
void InitAndInfer(const std::string& model_dir, const std::string& image_file,
const std::string& cls_result) {
auto model_file = model_dir + sep + "inference.pdmodel";
auto params_file = model_dir + sep + "inference.pdiparams";
auto config_file = model_dir + sep + "inference_cls.yaml";
fastdeploy::vision::EnableFlyCV();
fastdeploy::RuntimeOption option;
option.UseTimVX();
auto model = fastdeploy::vision::classification::PaddleClasModel(
model_file, params_file, config_file, option);
assert(model.Initialized());
auto im = cv::imread(image_file);
fastdeploy::vision::ClassifyResult res;
if (!model.Predict(im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
if (CompareClsResult(res, cls_result)) {
std::cout << model_dir + " run successfully." << std::endl;
} else {
std::cerr << model_dir + " run failed." << std::endl;
}
}
int main(int argc, char* argv[]) {
if (argc < 4) {
std::cout
<< "Usage: test_clas path/to/quant_model "
"path/to/image "
"e.g ./test_clas ./ResNet50_vd_quant ./test.jpeg resnet50_clas.txt"
<< std::endl;
return -1;
}
std::string model_dir = argv[1];
std::string test_image = argv[2];
std::string cls_result = argv[3];
InitAndInfer(model_dir, test_image, cls_result);
return 0;
}

76
tests/timvx/test_ppliteseg.cc Executable file
View File

@@ -0,0 +1,76 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common.h"
#include "fastdeploy/vision.h"
#ifdef WIN32
const char sep = '\\';
#else
const char sep = '/';
#endif
void InitAndInfer(const std::string& model_dir, const std::string& image_file,
const std::string& seg_result_file) {
auto model_file = model_dir + sep + "model.pdmodel";
auto params_file = model_dir + sep + "model.pdiparams";
auto config_file = model_dir + sep + "deploy.yaml";
auto subgraph_file = model_dir + sep + "subgraph.txt";
fastdeploy::vision::EnableFlyCV();
fastdeploy::RuntimeOption option;
option.UseTimVX();
option.SetLiteSubgraphPartitionPath(subgraph_file);
auto model = fastdeploy::vision::segmentation::PaddleSegModel(
model_file, params_file, config_file, option);
assert(model.Initialized());
auto im = cv::imread(image_file);
fastdeploy::vision::SegmentationResult res;
if (!model.Predict(im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
// std::cout << res.Str() << std::endl;
// std::ofstream res_str(seg_result_file);
// if(!WriteSegResult(res, seg_result_file)){
// std::cerr << "Fail to write to " << seg_result_file<<std::endl;
// }
// std::cout<<"file writen"<<std::endl;
if (CompareSegResult(res, seg_result_file)) {
std::cout << model_dir + " run successfully." << std::endl;
} else {
std::cerr << model_dir + " run failed." << std::endl;
}
}
int main(int argc, char* argv[]) {
if (argc < 3) {
std::cout << "Usage: infer_demo path/to/quant_model "
"path/to/image "
"e.g ./infer_demo ./ResNet50_vd_quant ./test.jpeg "
"./ppliteseg_result.txt"
<< std::endl;
return -1;
}
std::string model_dir = argv[1];
std::string test_image = argv[2];
std::string seg_result_file = argv[3];
InitAndInfer(model_dir, test_image, seg_result_file);
return 0;
}

68
tests/timvx/test_ppyoloe.cc Executable file
View File

@@ -0,0 +1,68 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "common.h"
#include "fastdeploy/vision.h"
#ifdef WIN32
const char sep = '\\';
#else
const char sep = '/';
#endif
void InitAndInfer(const std::string& model_dir, const std::string& image_file,
const std::string& det_result_file) {
auto model_file = model_dir + sep + "model.pdmodel";
auto params_file = model_dir + sep + "model.pdiparams";
auto config_file = model_dir + sep + "infer_cfg.yml";
auto subgraph_file = model_dir + sep + "subgraph.txt";
fastdeploy::vision::EnableFlyCV();
fastdeploy::RuntimeOption option;
option.UseTimVX();
option.SetLiteSubgraphPartitionPath(subgraph_file);
auto model = fastdeploy::vision::detection::PPYOLOE(model_file, params_file,
config_file, option);
assert(model.Initialized());
auto im = cv::imread(image_file);
fastdeploy::vision::DetectionResult res;
if (!model.Predict(im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
if (CompareDetResult(res, det_result_file)) {
std::cout << model_dir + " run successfully." << std::endl;
} else {
std::cerr << model_dir + " run failed." << std::endl;
}
}
int main(int argc, char* argv[]) {
if (argc < 4) {
std::cout
<< "Usage: infer_demo path/to/quant_model "
"path/to/image "
"e.g ./infer_demo ./PPYOLOE_L_quant ./test.jpeg ./ppyoloe_result.txt"
<< std::endl;
return -1;
}
std::string model_dir = argv[1];
std::string test_image = argv[2];
std::string det_result_file = argv[3];
InitAndInfer(model_dir, test_image, det_result_file);
return 0;
}

71
tests/timvx/test_yolov5.cc Executable file
View File

@@ -0,0 +1,71 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <fstream>
#include "common.h"
#include "fastdeploy/vision.h"
#ifdef WIN32
const char sep = '\\';
#else
const char sep = '/';
#endif
void InferAndCompare(const std::string& model_dir,
const std::string& image_file,
const std::string& det_result) {
auto model_file = model_dir + sep + "model.pdmodel";
auto params_file = model_dir + sep + "model.pdiparams";
auto subgraph_file = model_dir + sep + "subgraph.txt";
fastdeploy::vision::EnableFlyCV();
fastdeploy::RuntimeOption option;
option.UseTimVX();
option.SetLiteSubgraphPartitionPath(subgraph_file);
auto model = fastdeploy::vision::detection::YOLOv5(
model_file, params_file, option, fastdeploy::ModelFormat::PADDLE);
assert(model.Initialized());
auto im = cv::imread(image_file);
fastdeploy::vision::DetectionResult res;
if (!model.Predict(im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return;
}
if (CompareDetResult(res, det_result)) {
std::cout << model_dir + " run successfully." << std::endl;
} else {
std::cerr << model_dir + " run failed." << std::endl;
}
}
int main(int argc, char* argv[]) {
if (argc < 4) {
std::cout << "Usage: infer_demo path/to/quant_model "
"path/to/image "
"run_option, "
"e.g ./infer_demo ./yolov5s_quant ./000000014439.jpg "
"yolov5_result.txt"
<< std::endl;
return -1;
}
std::string model_dir = argv[1];
std::string test_image = argv[2];
std::string det_result = argv[3];
InferAndCompare(model_dir, test_image, det_result);
return 0;
}