Add publish task example into test directory (#239)

* Add publish_task_example

* Update CMakeLists.txt

* Add gflags cmake

* Update release task script

* Delete windows related code in run.sh && add openvino option
huangjianhui
2022-09-21 13:26:54 +08:00
committed by GitHub
parent e7f741292e
commit 463ee0a088
6 changed files with 484 additions and 0 deletions

tests/release_task/CMakeLists.txt Normal file

@@ -0,0 +1,16 @@
PROJECT(infer_demo C CXX)
CMAKE_MINIMUM_REQUIRED (VERSION 3.12)
# Specify the path of the downloaded and extracted FastDeploy SDK
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")
find_package(Threads)
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
include(${CMAKE_CURRENT_SOURCE_DIR}/gflags.cmake)
# Add FastDeploy dependency headers
include_directories(${FASTDEPLOY_INCS} ${GFLAGS_INCLUDE_DIR})
add_executable(infer_ppyoloe_demo ${PROJECT_SOURCE_DIR}/infer_ppyoloe.cc)
target_link_libraries(infer_ppyoloe_demo ${FASTDEPLOY_LIBS} gflags ${CMAKE_THREAD_LIBS_INIT})
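
A minimal out-of-source configure-and-build sketch for this CMakeLists.txt, mirroring the steps run.sh performs below; the SDK directory name is illustrative and depends on which C++ package was downloaded:

mkdir build && cd build
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/../fastdeploy-linux-x64-x.x.x
make -j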

tests/release_task/compare_with_gt.py Normal file

@@ -0,0 +1,60 @@
import numpy as np
import re


def parse_arguments():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gt_path",
        type=str,
        required=True,
        help="Path of the ground truth result file.")
    parser.add_argument(
        "--result_path",
        type=str,
        required=True,
        help="Path of the inference result file.")
    parser.add_argument(
        "--platform", type=str, required=True, help="Testcase platform.")
    parser.add_argument(
        "--device", type=str, required=True, help="Testcase device.")
    args = parser.parse_args()
    return args


def convert2numpy(result_file):
    # Keep only the lines that contain exactly 6 numbers
    # (box coordinates, score and label) and stack them into an array.
    result = []
    with open(result_file, "r") as f:
        for line in f.readlines():
            data = re.findall(r"\d+\.?\d*", line)
            if len(data) == 6:
                result.append([float(num) for num in data])
    return np.array(result)


def write2file(error_file, args):
    # Record the failing platform, Python version and testcase name.
    with open(error_file, "w+") as f:
        from platform import python_version
        py_version = python_version()
        f.write(args.platform + " " + py_version + " " +
                args.result_path.split(".")[0] + "\n")


def check_result(gt_result, infer_result, args):
    if len(gt_result) != len(infer_result):
        infer_result = infer_result[-len(gt_result):]
    diff = np.abs(gt_result - infer_result)
    # Report a diff if any element deviates from the ground truth by more than 1e-5.
    if (diff > 1e-5).any():
        print(args.platform, args.device, "diff ", diff)
        write2file("result.txt", args)
    else:
        print(args.platform, args.device, "No diff")


if __name__ == '__main__':
    args = parse_arguments()
    gt_numpy = convert2numpy(args.gt_path)
    infer_numpy = convert2numpy(args.result_path)
    check_result(gt_numpy, infer_numpy, args)
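
A usage sketch of this comparison script, matching how run.sh invokes it below (the ground-truth file name is the one run.sh downloads, and py_cpu_result.txt is one of the result files it produces):

python compare_with_gt.py --gt_path release_task_groud_truth_result.txt \
    --result_path py_cpu_result.txt --platform linux-x64 --device cpu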

tests/release_task/gflags.cmake Normal file

@@ -0,0 +1,76 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
INCLUDE(ExternalProject)
option(GIT_URL "Git URL to clone dependent repos" ${GIT_URL})
if(NOT GIT_URL)
set(GIT_URL "https://github.com")
endif()
SET(GFLAGS_PREFIX_DIR ${CMAKE_CURRENT_SOURCE_DIR}/gflags)
SET(GFLAGS_INSTALL_DIR ${CMAKE_CURRENT_SOURCE_DIR}/install/gflags)
SET(GFLAGS_INCLUDE_DIR "${GFLAGS_INSTALL_DIR}/include" CACHE PATH "gflags include directory." FORCE)
set(GFLAGS_REPOSITORY ${GIT_URL}/gflags/gflags.git)
set(GFLAGS_TAG "v2.2.2")
IF(WIN32)
set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/gflags_static.lib" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE)
ELSE(WIN32)
set(GFLAGS_LIBRARIES "${GFLAGS_INSTALL_DIR}/lib/libgflags.a" CACHE FILEPATH "GFLAGS_LIBRARIES" FORCE)
set(BUILD_COMMAND $(MAKE) --silent)
set(INSTALL_COMMAND $(MAKE) install)
ENDIF(WIN32)
INCLUDE_DIRECTORIES(${GFLAGS_INCLUDE_DIR})
ExternalProject_Add(
extern_gflags
${EXTERNAL_PROJECT_LOG_ARGS}
${SHALLOW_CLONE}
GIT_REPOSITORY ${GFLAGS_REPOSITORY}
GIT_TAG ${GFLAGS_TAG}
PREFIX ${GFLAGS_PREFIX_DIR}
UPDATE_COMMAND ""
BUILD_COMMAND ${BUILD_COMMAND}
INSTALL_COMMAND ${INSTALL_COMMAND}
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
-DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
-DBUILD_STATIC_LIBS=ON
-DCMAKE_INSTALL_PREFIX=${GFLAGS_INSTALL_DIR}
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DBUILD_TESTING=OFF
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${GFLAGS_INSTALL_DIR}
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
BUILD_BYPRODUCTS ${GFLAGS_LIBRARIES}
)
ADD_LIBRARY(gflags STATIC IMPORTED GLOBAL)
SET_PROPERTY(TARGET gflags PROPERTY IMPORTED_LOCATION ${GFLAGS_LIBRARIES})
ADD_DEPENDENCIES(gflags extern_gflags)
# On Windows (including MinGW), the Shlwapi library is used by gflags if available.
if (WIN32)
include(CheckIncludeFileCXX)
check_include_file_cxx("shlwapi.h" HAVE_SHLWAPI)
if (HAVE_SHLWAPI)
set_property(GLOBAL PROPERTY OS_DEPENDENCY_MODULES shlwapi.lib)
endif(HAVE_SHLWAPI)
endif (WIN32)
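
For reference only, the ExternalProject step above amounts roughly to the following manual sequence (a sketch assuming a Linux host; the real clone, build and install are driven by CMake into ./gflags and ./install/gflags):

git clone --branch v2.2.2 https://github.com/gflags/gflags.git
mkdir -p gflags_build && cd gflags_build
cmake ../gflags -DBUILD_STATIC_LIBS=ON -DBUILD_TESTING=OFF \
    -DCMAKE_POSITION_INDEPENDENT_CODE=ON \
    -DCMAKE_INSTALL_PREFIX=$PWD/../install/gflags
make --silent && make install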

tests/release_task/infer_ppyoloe.cc Normal file

@@ -0,0 +1,148 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include "fastdeploy/vision.h"
#ifdef WIN32
const char sep = '\\';
#else
const char sep = '/';
#endif
DEFINE_string(model_dir, "", "Path of inference model");
DEFINE_string(image_file, "", "Path of input image");
DEFINE_string(device, "CPU",
"Choose the device you want to run, it can be: CPU/GPU, "
"default is CPU.");
DEFINE_string(backend, "default",
"Set inference backend, support one of ['default', 'ort', "
"'paddle', 'trt', 'openvino']");
void CpuInfer(const std::string& model_dir, const std::string& image_file) {
  auto model_file = model_dir + sep + "model.pdmodel";
  auto params_file = model_dir + sep + "model.pdiparams";
  auto config_file = model_dir + sep + "infer_cfg.yml";
  auto option = fastdeploy::RuntimeOption();
  option.UseCpu();
  if (FLAGS_backend == "ort") {
    option.UseOrtBackend();
  } else if (FLAGS_backend == "paddle") {
    option.UsePaddleBackend();
  } else if (FLAGS_backend == "trt") {
    std::cerr << "Using --backend=trt for inference requires --device=gpu"
              << std::endl;
    return;
  } else if (FLAGS_backend == "openvino") {
    option.UseOpenVINOBackend();
  } else if (FLAGS_backend == "default") {
    std::cout << "Use default backend for inference" << std::endl;
  } else {
    std::cerr << "Don't support backend type: " + FLAGS_backend << std::endl;
    return;
  }
  auto model = fastdeploy::vision::detection::PPYOLOE(model_file, params_file,
                                                      config_file, option);
  if (!model.Initialized()) {
    std::cerr << "Failed to initialize." << std::endl;
    return;
  }
  auto im = cv::imread(image_file);
  auto im_bak = im.clone();
  fastdeploy::vision::DetectionResult res;
  if (!model.Predict(&im, &res)) {
    std::cerr << "Failed to predict." << std::endl;
    return;
  }
  std::cout << res.Str() << std::endl;
  auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res, 0.5);
  cv::imwrite("vis_result.jpg", vis_im);
  std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
void GpuInfer(const std::string& model_dir, const std::string& image_file) {
  auto model_file = model_dir + sep + "model.pdmodel";
  auto params_file = model_dir + sep + "model.pdiparams";
  auto config_file = model_dir + sep + "infer_cfg.yml";
  auto option = fastdeploy::RuntimeOption();
  option.UseGpu();
  if (FLAGS_backend == "ort") {
    option.UseOrtBackend();
  } else if (FLAGS_backend == "paddle") {
    option.UsePaddleBackend();
  } else if (FLAGS_backend == "trt") {
    option.UseTrtBackend();
    option.SetTrtInputShape("image", {1, 3, 640, 640});
    option.SetTrtInputShape("scale_factor", {1, 2});
  } else if (FLAGS_backend == "openvino") {
    std::cerr << "Using --backend=openvino for inference requires --device=cpu"
              << std::endl;
    return;
  } else if (FLAGS_backend == "default") {
    std::cout << "Use default backend for inference" << std::endl;
  } else {
    std::cerr << "Don't support backend type: " + FLAGS_backend << std::endl;
    return;
  }
  auto model = fastdeploy::vision::detection::PPYOLOE(model_file, params_file,
                                                      config_file, option);
  if (!model.Initialized()) {
    std::cerr << "Failed to initialize." << std::endl;
    return;
  }
  auto im = cv::imread(image_file);
  auto im_bak = im.clone();
  fastdeploy::vision::DetectionResult res;
  if (!model.Predict(&im, &res)) {
    std::cerr << "Failed to predict." << std::endl;
    return;
  }
  std::cout << res.Str() << std::endl;
  auto vis_im = fastdeploy::vision::Visualize::VisDetection(im_bak, res, 0.5);
  cv::imwrite("vis_result.jpg", vis_im);
  std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
int main(int argc, char* argv[]) {
  // Parse command-line flags.
  google::ParseCommandLineFlags(&argc, &argv, true);
  if (FLAGS_model_dir.empty() || FLAGS_image_file.empty()) {
    std::cout << "Usage: infer_ppyoloe_demo --model_dir=/path/to/model_dir "
                 "--image_file=/path/to/image --device=device, "
                 "e.g. ./infer_ppyoloe_demo --model_dir=./ppyoloe_model_dir "
                 "--image_file=./test.jpeg --device=cpu"
              << std::endl;
    std::cout << "For more information, use ./infer_ppyoloe_demo --help"
              << std::endl;
    return -1;
  }
  if (FLAGS_device == "cpu") {
    CpuInfer(FLAGS_model_dir, FLAGS_image_file);
  } else if (FLAGS_device == "gpu") {
    GpuInfer(FLAGS_model_dir, FLAGS_image_file);
  } else {
    std::cerr << "Don't support device type: " + FLAGS_device << std::endl;
    return -1;
  }
  return 0;
}
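
As a concrete invocation sketch, the built binary is driven by run.sh below roughly like this (the model directory and image are the ones run.sh downloads):

./infer_ppyoloe_demo --model_dir=./ppyoloe_crn_l_300e_coco \
    --image_file=./000000014439.jpg --device=gpu --backend=trt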

tests/release_task/infer_ppyoloe.py Normal file

@@ -0,0 +1,79 @@
import fastdeploy as fd
import cv2
import os


def parse_arguments():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_dir",
        required=True,
        help="Path of PaddleDetection model directory.")
    parser.add_argument(
        "--image", required=True, help="Path of test image file.")
    parser.add_argument(
        "--device",
        type=str,
        default='cpu',
        help="Type of inference device, support 'cpu' or 'gpu'.")
    parser.add_argument(
        "--backend",
        nargs='?',
        type=str,
        default='default',
        help="Set inference backend, support one of ['default', 'ort', 'paddle', 'trt', 'openvino']."
    )
    return parser.parse_args()


def build_option(args):
    option = fd.RuntimeOption()
    if args.device.lower() == "gpu":
        option.use_gpu()
    if args.backend == "ort":
        option.use_ort_backend()
    elif args.backend == "paddle":
        option.use_paddle_backend()
    elif args.backend == "trt":
        assert args.device.lower() == "gpu", "The trt backend requires --device gpu."
        option.use_trt_backend()
        option.set_trt_input_shape("image", [1, 3, 640, 640])
        option.set_trt_input_shape("scale_factor", [1, 2])
    elif args.backend == 'openvino':
        assert args.device.lower() == "cpu", "The openvino backend requires --device cpu."
        option.use_openvino_backend()
    elif args.backend == "default":
        pass
    else:
        raise Exception(
            "Don't support backend type: {}, please use one of ['default', 'ort', 'paddle', 'trt', 'openvino'].".
            format(args.backend))
    return option


args = parse_arguments()

model_file = os.path.join(args.model_dir, "model.pdmodel")
params_file = os.path.join(args.model_dir, "model.pdiparams")
config_file = os.path.join(args.model_dir, "infer_cfg.yml")

# Configure the runtime and load the model
runtime_option = build_option(args)
model = fd.vision.detection.PPYOLOE(
    model_file, params_file, config_file, runtime_option=runtime_option)

# Run detection on the image
im = cv2.imread(args.image)
result = model.predict(im.copy())
print(result)

# Visualize the detection result
vis_im = fd.vision.vis_detection(im, result, score_threshold=0.5)
cv2.imwrite("visualized_result.jpg", vis_im)
print("Visualized result saved in ./visualized_result.jpg")

tests/release_task/run.sh Normal file

@@ -0,0 +1,105 @@
#!/bin/bash
export no_proxy=bcebos.com
CURRENT_DIR=$(cd $(dirname $0); pwd)
PLATFORM=$1
DEVICE=$2
VERSION=$3
if [ "$DEVICE" = "gpu" ];then
PY_FASTDEPLOY_PACKAGE=fastdeploy-$DEVICE-python
CPP_FASTDEPLOY_PACKAGE=fastdeploy-$PLATFORM-$DEVICE-$VERSION
else
PY_FASTDEPLOY_PACKAGE=fastdeploy-python
CPP_FASTDEPLOY_PACKAGE=fastdeploy-$PLATFORM-$VERSION
fi
echo $CPP_FASTDEPLOY_PACKAGE
echo $PY_FASTDEPLOY_PACKAGE
PY_VERSION_CASE=('python3.6' 'python3.7' 'python3.8' 'python3.9' 'python3.10')
LINUX_X64_GPU_CASE=('ort' 'paddle' 'trt')
LINUX_X64_CPU_CASE=('ort' 'paddle' 'openvino')
LINUX_AARCH_CPU_CASE=('ort' 'openvino')
MACOS_INTEL_CPU_CASE=('ort' 'paddle' 'openvino')
MACOS_ARM64_CPU_CASE=('default')
wget -q https://bj.bcebos.com/paddlehub/fastdeploy/ppyoloe_crn_l_300e_coco.tgz
wget -q https://gitee.com/paddlepaddle/PaddleDetection/raw/release/2.4/demo/000000014439.jpg
wget -q https://bj.bcebos.com/paddlehub/fastdeploy/release_task_groud_truth_result.txt
tar -xvf ppyoloe_crn_l_300e_coco.tgz
IMAGE_PATH=$CURRENT_DIR/000000014439.jpg
MODEL_PATH=$CURRENT_DIR/ppyoloe_crn_l_300e_coco
GROUND_TRUTH_PATH=$CURRENT_DIR/release_task_groud_truth_result.txt
COMPARE_SHELL=$CURRENT_DIR/compare_with_gt.py
RUN_CASE=()
if [ "$DEVICE" = "gpu" ] && [ "$PLATFORM" = "linux-x64" ];then
RUN_CASE=(${LINUX_X64_GPU_CASE[*]})
elif [ "$DEVICE" = "cpu" ] && [ "$PLATFORM" = "linux-x64" ];then
RUN_CASE=(${LINUX_X64_CPU_CASE[*]})
elif [ "$DEVICE" = "cpu" ] && [ "$PLATFORM" = "linux-aarch64" ];then
RUN_CASE=(${LINUX_AARCH_CPU_CASE[*]})
elif [ "$DEVICE" = "cpu" ] && [ "$PLATFORM" = "osx-x86_64" ];then
RUN_CASE=(${MACOS_INTEL_CPU_CASE[*]})
elif [ "$DEVICE" = "cpu" ] && [ "$PLATFORM" = "osx-arm64" ];then
RUN_CASE=(${MACOS_ARM64_CPU_CASE[*]})
fi
py_version_case_number=${#PY_VERSION_CASE[@]}
case_number=${#RUN_CASE[@]}
for((i=0;i<py_version_case_number;i+=1))
do
py_version=${PY_VERSION_CASE[i]}
echo "py_version:" $py_version
$py_version -m pip freeze | grep fastdeploy | xargs $py_version -m pip uninstall -y
$py_version -m pip install $PY_FASTDEPLOY_PACKAGE -f https://www.paddlepaddle.org.cn/whl/fastdeploy_nightly_build.html
for((j=0;j<case_number;j+=1))
do
backend=${RUN_CASE[j]}
echo "Python Backend:" $backend
if [ "$backend" != "trt" ];then
$py_version infer_ppyoloe.py --model_dir $MODEL_PATH --image $IMAGE_PATH --device cpu --backend $backend >> py_cpu_result.txt
$py_version $COMPARE_SHELL --gt_path $GROUND_TRUTH_PATH --result_path py_cpu_result.txt --platform $PLATFORM --device cpu
fi
if [ "$DEVICE" = "gpu" ];then
if [ "$backend" = "trt" ];then
$py_version infer_ppyoloe.py --model_dir $MODEL_PATH --image $IMAGE_PATH --device gpu --backend $backend >> py_trt_result.txt
$py_version $COMPARE_SHELL --gt_path $GROUND_TRUTH_PATH --result_path py_trt_result.txt --platform $PLATFORM --device trt
else
$py_version infer_ppyoloe.py --model_dir $MODEL_PATH --image $IMAGE_PATH --device gpu --backend $backend >> py_gpu_result.txt
$py_version $COMPARE_SHELL --gt_path $GROUND_TRUTH_PATH --result_path py_gpu_result.txt --platform $PLATFORM --device gpu
fi
fi
done
done
wget -q https://fastdeploy.bj.bcebos.com/dev/cpp/$CPP_FASTDEPLOY_PACKAGE.tgz
tar xvf $CPP_FASTDEPLOY_PACKAGE.tgz
mkdir build && cd build
cmake .. -DFASTDEPLOY_INSTALL_DIR=${PWD}/../$CPP_FASTDEPLOY_PACKAGE
make -j
for((i=0;i<case_number;i+=1))
do
backend=${RUN_CASE[i]}
echo "Cpp Backend:" $backend
if [ "$backend" != "trt" ];then
./infer_ppyoloe_demo --model_dir=$MODEL_PATH --image_file=$IMAGE_PATH --device=cpu --backend=$backend >> cpp_cpu_result.txt
python $COMPARE_SHELL --gt_path $GROUND_TRUTH_PATH --result_path cpp_cpu_result.txt --platform $PLATFORM --device cpu
fi
if [ "$DEVICE" = "gpu" ];then
if [ "$backend" = "trt" ];then
./infer_ppyoloe_demo --model_dir=$MODEL_PATH --image_file=$IMAGE_PATH --device=gpu --backend=$backend >> cpp_trt_result.txt
python $COMPARE_SHELL --gt_path $GROUND_TRUTH_PATH --result_path cpp_trt_result.txt --platform $PLATFORM --device trt
else
./infer_ppyoloe_demo --model_dir=$MODEL_PATH --image_file=$IMAGE_PATH --device=gpu --backend=$backend >> cpp_gpu_result.txt
python $COMPARE_SHELL --gt_path $GROUND_TRUTH_PATH --result_path cpp_gpu_result.txt --platform $PLATFORM --device gpu
fi
fi
done
res_file="result.txt"
if [ ! -f $res_file ]; then
exit 0
else
cat $res_file
exit 1
fi
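
The script takes the platform, device, and FastDeploy package version as positional arguments; an invocation sketch (the version placeholder must match an actual published package name):

bash run.sh linux-x64 gpu <version>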