[Backend] support benchmark mode for runtime and backend (#1201)

* [backend] support benchmark mode for runtime and backend

* [backend] support benchmark mode for runtime and backend

* [pybind11] add benchmark methods pybind

* [pybind11] add benchmark methods pybind

* [Other] Update build scripts

* [Other] Update cmake/summary.cmake

* [Other] update build scripts

* [Other] add ENABLE_BENCHMARK option -> setup.py

* optimize backend time recording

* optimize backend time recording

* optimize trt backend time record

* [backend] optimize backend_time recording for trt

* [benchmark] remove redundant logs

* fixed ov_backend conflict

* [benchmark] fixed paddle_backend conflicts

* [benchmark] fixed paddle_backend conflicts

* [benchmark] fixed paddle_backend conflicts

* [benchmark] remove use_gpu option from ort backend option

* [benchmark] update benchmark_ppdet.py

* [benchmark] update benchmark_ppcls.py

* fixed lite backend conflicts

* [Lite] fixed lite xpu

* add benchmark macro

* add RUNTIME_PROFILE_LOOP macros

* add comments for RUNTIME_PROFILE macros

* add comments for new apis

* add comments for new apis

* update benchmark_ppdet.py

* fixed bugs

* remove unused codes

* optimize RUNTIME_PROFILE_LOOP macros

* optimize RUNTIME_PROFILE_LOOP macros

* add comments for benchmark option and result

* add docs for benchmark namespace
DefTruth
2023-02-06 14:29:35 +08:00
committed by GitHub
parent 42d14e7119
commit f73a538f61
34 changed files with 741 additions and 91 deletions


@@ -68,6 +68,7 @@ option(ENABLE_TEXT "Whether to enable text models usage." OFF)
 option(ENABLE_FLYCV "Whether to enable flycv to boost image preprocess." OFF)
 option(ENABLE_CVCUDA "Whether to enable NVIDIA CV-CUDA to boost image preprocess." OFF)
 option(ENABLE_ENCRYPTION "Whether to enable ENCRYPTION." OFF)
+option(ENABLE_BENCHMARK "Whether to enable Benchmark mode." OFF)
 option(WITH_ASCEND "Whether to compile for Huawei Ascend deploy." OFF)
 option(WITH_TIMVX "Whether to compile for TIMVX deploy." OFF)
 option(WITH_KUNLUNXIN "Whether to compile for KunlunXin XPU deploy." OFF)

13
benchmark/.gitignore vendored Normal file

@@ -0,0 +1,13 @@
*.tgz
*.zip
*.tar
*.tar.gz
*.tgz
*.jpg
*.png
*.jpeg
*.txt
*.log
yolov8_s_*
._yolov8_s_*
Mobile*


@@ -17,7 +17,7 @@ import cv2
 import os
 import numpy as np
 import time
+from tqdm import tqdm
 def parse_arguments():
     import argparse
@@ -35,11 +35,22 @@ def parse_arguments():
     parser.add_argument(
         "--device_id", type=int, default=0, help="device(gpu) id")
     parser.add_argument(
-        "--iter_num",
+        "--profile_mode",
+        type=str,
+        default="runtime",
+        help="runtime or end2end.")
+    parser.add_argument(
+        "--repeat",
         required=True,
         type=int,
-        default=300,
-        help="number of iterations for computing performace.")
+        default=1000,
+        help="number of repeats for profiling.")
+    parser.add_argument(
+        "--warmup",
+        required=True,
+        type=int,
+        default=50,
+        help="number of warmup for profiling.")
     parser.add_argument(
         "--device",
         default="cpu",
@@ -59,6 +70,11 @@ def parse_arguments():
         type=ast.literal_eval,
         default=False,
         help="whether enable collect memory info")
+    parser.add_argument(
+        "--include_h2d_d2h",
+        type=ast.literal_eval,
+        default=False,
+        help="whether run profiling with h2d and d2h")
     args = parser.parse_args()
     return args
@@ -68,6 +84,8 @@ def build_option(args):
     device = args.device
     backend = args.backend
     enable_trt_fp16 = args.enable_trt_fp16
+    if args.profile_mode == "runtime":
+        option.enable_profiling(args.include_h2d_d2h, args.repeat, args.warmup)
     option.set_cpu_thread_num(args.cpu_num_thread)
     if device == "gpu":
         option.use_gpu()
@@ -229,7 +247,6 @@ if __name__ == '__main__':
     gpu_id = args.device_id
     enable_collect_memory_info = args.enable_collect_memory_info
     dump_result = dict()
-    end2end_statis = list()
     cpu_mem = list()
     gpu_mem = list()
     gpu_util = list()
@@ -258,18 +275,26 @@ if __name__ == '__main__':
             monitor = Monitor(enable_gpu, gpu_id)
             monitor.start()
-        model.enable_record_time_of_runtime()
         im_ori = cv2.imread(args.image)
-        for i in range(args.iter_num):
-            im = im_ori
-            start = time.time()
-            result = model.predict(im)
-            end2end_statis.append(time.time() - start)
-        runtime_statis = model.print_statis_info_of_runtime()
-        warmup_iter = args.iter_num // 5
-        end2end_statis_repeat = end2end_statis[warmup_iter:]
+        if args.profile_mode == "runtime":
+            result = model.predict(im_ori)
+            profile_time = model.get_profile_time()
+            dump_result["runtime"] = profile_time * 1000
+            f.writelines("Runtime(ms): {} \n".format(str(dump_result["runtime"])))
+            print("Runtime(ms): {} \n".format(str(dump_result["runtime"])))
+        else:
+            # end2end
+            for i in range(args.warmup):
+                result = model.predict(im_ori)
+            start = time.time()
+            for i in tqdm(range(args.repeat)):
+                result = model.predict(im_ori)
+            end = time.time()
+            dump_result["end2end"] = ((end - start) / args.repeat) * 1000.0
+            f.writelines("End2End(ms): {} \n".format(str(dump_result["end2end"])))
+            print("End2End(ms): {} \n".format(str(dump_result["end2end"])))
         if enable_collect_memory_info:
             monitor.stop()
             mem_info = monitor.output()
@@ -280,13 +305,6 @@ if __name__ == '__main__':
dump_result["gpu_util"] = mem_info['gpu'][ dump_result["gpu_util"] = mem_info['gpu'][
'utilization.gpu'] if 'gpu' in mem_info else 0 'utilization.gpu'] if 'gpu' in mem_info else 0
dump_result["runtime"] = runtime_statis["avg_time"] * 1000
dump_result["end2end"] = np.mean(end2end_statis_repeat) * 1000
f.writelines("Runtime(ms): {} \n".format(str(dump_result["runtime"])))
f.writelines("End2End(ms): {} \n".format(str(dump_result["end2end"])))
print("Runtime(ms): {} \n".format(str(dump_result["runtime"])))
print("End2End(ms): {} \n".format(str(dump_result["end2end"])))
if enable_collect_memory_info: if enable_collect_memory_info:
f.writelines("cpu_rss_mb: {} \n".format( f.writelines("cpu_rss_mb: {} \n".format(
str(dump_result["cpu_rss_mb"]))) str(dump_result["cpu_rss_mb"])))
@@ -297,7 +315,8 @@ if __name__ == '__main__':
print("cpu_rss_mb: {} \n".format(str(dump_result["cpu_rss_mb"]))) print("cpu_rss_mb: {} \n".format(str(dump_result["cpu_rss_mb"])))
print("gpu_rss_mb: {} \n".format(str(dump_result["gpu_rss_mb"]))) print("gpu_rss_mb: {} \n".format(str(dump_result["gpu_rss_mb"])))
print("gpu_util: {} \n".format(str(dump_result["gpu_util"]))) print("gpu_util: {} \n".format(str(dump_result["gpu_util"])))
except: except Exception as e:
f.writelines("!!!!!Infer Failed\n") f.writelines("!!!!!Infer Failed\n")
raise e
f.close() f.close()
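The two profiling paths added to this benchmark script boil down to a short standalone sketch. The snippet below is illustrative only: the model class, file paths, image and iteration counts are placeholders, and enable_profiling is called positionally to match the binding introduced in this commit.

import time
import cv2
import fastdeploy as fd

option = fd.RuntimeOption()
# "runtime" mode: the Runtime repeats inference internally (warmup + repeat)
# and a single predict() call returns once profiling is done.
option.enable_profiling(False, 300, 50)  # include_h2d_d2h, repeat, warmup

# Hypothetical model and files; any FastDeploy vision model follows the same pattern.
model = fd.vision.detection.PPYOLOE(
    "model.pdmodel", "model.pdiparams", "infer_cfg.yml", runtime_option=option)

im = cv2.imread("test.jpg")
model.predict(im)
print("Runtime(ms):", model.get_profile_time() * 1000)

# "end2end" mode: time predict() from Python, pre/postprocessing included.
for _ in range(50):          # warmup
    model.predict(im)
start = time.time()
for _ in range(300):
    model.predict(im)
print("End2End(ms):", (time.time() - start) / 300 * 1000)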


@@ -17,6 +17,7 @@ import cv2
 import os
 import numpy as np
 import time
+from sympy import EX
 from tqdm import tqdm
 def parse_arguments():
@@ -24,7 +25,7 @@ def parse_arguments():
     import ast
     parser = argparse.ArgumentParser()
     parser.add_argument(
-        "--model", required=True, help="Path of PaddleDetection model.")
+        "--model", required=True, help="Path of PaddleClas model.")
     parser.add_argument(
         "--image", type=str, required=False, help="Path of test image file.")
     parser.add_argument(
@@ -35,20 +36,31 @@ def parse_arguments():
     parser.add_argument(
         "--device_id", type=int, default=0, help="device(gpu) id")
     parser.add_argument(
-        "--iter_num",
+        "--profile_mode",
+        type=str,
+        default="runtime",
+        help="runtime or end2end.")
+    parser.add_argument(
+        "--repeat",
         required=True,
         type=int,
-        default=300,
-        help="number of iterations for computing performace.")
+        default=1000,
+        help="number of repeats for profiling.")
+    parser.add_argument(
+        "--warmup",
+        required=True,
+        type=int,
+        default=50,
+        help="number of warmup for profiling.")
     parser.add_argument(
         "--device",
         default="cpu",
-        help="Type of inference device, support 'cpu', 'gpu', 'kunlunxin', 'ascend' etc.")
+        help="Type of inference device, support 'cpu' or 'gpu'.")
     parser.add_argument(
         "--backend",
         type=str,
         default="default",
-        help="inference backend, default, ort, ov, trt, paddle, paddle_trt, lite.")
+        help="inference backend, default, ort, ov, trt, paddle, paddle_trt.")
     parser.add_argument(
         "--enable_trt_fp16",
         type=ast.literal_eval,
@@ -58,12 +70,17 @@ def parse_arguments():
"--enable_lite_fp16", "--enable_lite_fp16",
type=ast.literal_eval, type=ast.literal_eval,
default=False, default=False,
help="whether enable fp16 in lite backend") help="whether enable fp16 in Paddle Lite backend")
parser.add_argument( parser.add_argument(
"--enable_collect_memory_info", "--enable_collect_memory_info",
type=ast.literal_eval, type=ast.literal_eval,
default=False, default=False,
help="whether enable collect memory info") help="whether enable collect memory info")
parser.add_argument(
"--include_h2d_d2h",
type=ast.literal_eval,
default=False,
help="whether run profiling with h2d and d2h")
args = parser.parse_args() args = parser.parse_args()
return args return args
@@ -74,6 +91,8 @@ def build_option(args):
     backend = args.backend
     enable_trt_fp16 = args.enable_trt_fp16
     enable_lite_fp16 = args.enable_lite_fp16
+    if args.profile_mode == "runtime":
+        option.enable_profiling(args.include_h2d_d2h, args.repeat, args.warmup)
     option.set_cpu_thread_num(args.cpu_num_thread)
     if device == "gpu":
         option.use_gpu()
@@ -266,8 +285,12 @@ if __name__ == '__main__':
     gpu_id = args.device_id
     enable_collect_memory_info = args.enable_collect_memory_info
+    enable_record_time_of_backend = args.enable_record_time_of_backend
+    backend_repeat = args.backend_repeat
     dump_result = dict()
     end2end_statis = list()
+    prepost_statis = list()
+    h2d_d2h_statis = list()
     cpu_mem = list()
     gpu_mem = list()
     gpu_util = list()
@@ -317,18 +340,26 @@ if __name__ == '__main__':
             monitor = Monitor(enable_gpu, gpu_id)
             monitor.start()
-        model.enable_record_time_of_runtime()
         im_ori = cv2.imread(args.image)
-        for i in tqdm(range(args.iter_num)):
-            im = im_ori
-            start = time.time()
-            result = model.predict(im)
-            end2end_statis.append(time.time() - start)
-        runtime_statis = model.print_statis_info_of_runtime()
-        warmup_iter = args.iter_num // 5
-        end2end_statis_repeat = end2end_statis[warmup_iter:]
+        if args.profile_mode == "runtime":
+            result = model.predict(im_ori)
+            profile_time = model.get_profile_time()
+            dump_result["runtime"] = profile_time * 1000
+            f.writelines("Runtime(ms): {} \n".format(str(dump_result["runtime"])))
+            print("Runtime(ms): {} \n".format(str(dump_result["runtime"])))
+        else:
+            # end2end
+            for i in range(args.warmup):
+                result = model.predict(im_ori)
+            start = time.time()
+            for i in tqdm(range(args.repeat)):
+                result = model.predict(im_ori)
+            end = time.time()
+            dump_result["end2end"] = ((end - start) / args.repeat) * 1000.0
+            f.writelines("End2End(ms): {} \n".format(str(dump_result["end2end"])))
+            print("End2End(ms): {} \n".format(str(dump_result["end2end"])))
         if enable_collect_memory_info:
             monitor.stop()
             mem_info = monitor.output()
@@ -339,13 +370,6 @@ if __name__ == '__main__':
dump_result["gpu_util"] = mem_info['gpu'][ dump_result["gpu_util"] = mem_info['gpu'][
'utilization.gpu'] if 'gpu' in mem_info else 0 'utilization.gpu'] if 'gpu' in mem_info else 0
dump_result["runtime"] = runtime_statis["avg_time"] * 1000
dump_result["end2end"] = np.mean(end2end_statis_repeat) * 1000
f.writelines("Runtime(ms): {} \n".format(str(dump_result["runtime"])))
f.writelines("End2End(ms): {} \n".format(str(dump_result["end2end"])))
print("Runtime(ms): {} \n".format(str(dump_result["runtime"])))
print("End2End(ms): {} \n".format(str(dump_result["end2end"])))
if enable_collect_memory_info: if enable_collect_memory_info:
f.writelines("cpu_rss_mb: {} \n".format( f.writelines("cpu_rss_mb: {} \n".format(
str(dump_result["cpu_rss_mb"]))) str(dump_result["cpu_rss_mb"])))
@@ -356,7 +380,8 @@ if __name__ == '__main__':
print("cpu_rss_mb: {} \n".format(str(dump_result["cpu_rss_mb"]))) print("cpu_rss_mb: {} \n".format(str(dump_result["cpu_rss_mb"])))
print("gpu_rss_mb: {} \n".format(str(dump_result["gpu_rss_mb"]))) print("gpu_rss_mb: {} \n".format(str(dump_result["gpu_rss_mb"])))
print("gpu_util: {} \n".format(str(dump_result["gpu_util"]))) print("gpu_util: {} \n".format(str(dump_result["gpu_util"])))
except: except Exception as e:
f.writelines("!!!!!Infer Failed\n") f.writelines("!!!!!Infer Failed\n")
raise e
f.close() f.close()


@@ -39,6 +39,7 @@ function(fastdeploy_summary)
message(STATUS " ENABLE_POROS_BACKEND : ${ENABLE_POROS_BACKEND}") message(STATUS " ENABLE_POROS_BACKEND : ${ENABLE_POROS_BACKEND}")
message(STATUS " ENABLE_TRT_BACKEND : ${ENABLE_TRT_BACKEND}") message(STATUS " ENABLE_TRT_BACKEND : ${ENABLE_TRT_BACKEND}")
message(STATUS " ENABLE_OPENVINO_BACKEND : ${ENABLE_OPENVINO_BACKEND}") message(STATUS " ENABLE_OPENVINO_BACKEND : ${ENABLE_OPENVINO_BACKEND}")
message(STATUS " ENABLE_BENCHMARK : ${ENABLE_BENCHMARK}")
message(STATUS " WITH_GPU : ${WITH_GPU}") message(STATUS " WITH_GPU : ${WITH_GPU}")
message(STATUS " WITH_ASCEND : ${WITH_ASCEND}") message(STATUS " WITH_ASCEND : ${WITH_ASCEND}")
message(STATUS " WITH_TIMVX : ${WITH_TIMVX}") message(STATUS " WITH_TIMVX : ${WITH_TIMVX}")


@@ -0,0 +1,86 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/core/config.h"
#include "fastdeploy/utils/utils.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/benchmark/option.h"
#include "fastdeploy/benchmark/results.h"
#ifdef ENABLE_BENCHMARK
#define __RUNTIME_PROFILE_LOOP_BEGIN(option, base_loop) \
int __p_loop = (base_loop); \
const bool __p_enable_profile = option.enable_profile; \
const bool __p_include_h2d_d2h = option.include_h2d_d2h; \
const int __p_repeats = option.repeats; \
const int __p_warmup = option.warmup; \
if (__p_enable_profile && (!__p_include_h2d_d2h)) { \
__p_loop = (__p_repeats) + (__p_warmup); \
FDINFO << option << std::endl; \
} \
TimeCounter __p_tc; \
bool __p_tc_start = false; \
for (int __p_i = 0; __p_i < __p_loop; ++__p_i) { \
if (__p_i >= (__p_warmup) && (!__p_tc_start)) { \
__p_tc.Start(); \
__p_tc_start = true; \
} \
#define __RUNTIME_PROFILE_LOOP_END(result) \
} \
if ((__p_enable_profile && (!__p_include_h2d_d2h))) { \
if (__p_tc_start) { \
__p_tc.End(); \
double __p_tc_duration = __p_tc.Duration(); \
result.time_of_runtime = \
__p_tc_duration / static_cast<double>(__p_repeats); \
} \
}
#define __RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN(option, base_loop) \
int __p_loop_h = (base_loop); \
const bool __p_enable_profile_h = option.enable_profile; \
const bool __p_include_h2d_d2h_h = option.include_h2d_d2h; \
const int __p_repeats_h = option.repeats; \
const int __p_warmup_h = option.warmup; \
if (__p_enable_profile_h && __p_include_h2d_d2h_h) { \
__p_loop_h = (__p_repeats_h) + (__p_warmup_h); \
FDINFO << option << std::endl; \
} \
TimeCounter __p_tc_h; \
bool __p_tc_start_h = false; \
for (int __p_i_h = 0; __p_i_h < __p_loop_h; ++__p_i_h) { \
if (__p_i_h >= (__p_warmup_h) && (!__p_tc_start_h)) { \
__p_tc_h.Start(); \
__p_tc_start_h = true; \
} \
#define __RUNTIME_PROFILE_LOOP_H2D_D2H_END(result) \
} \
if ((__p_enable_profile_h && __p_include_h2d_d2h_h)) { \
if (__p_tc_start_h) { \
__p_tc_h.End(); \
double __p_tc_duration_h = __p_tc_h.Duration(); \
result.time_of_runtime = \
__p_tc_duration_h / static_cast<double>(__p_repeats_h); \
} \
}
#else
#define __RUNTIME_PROFILE_LOOP_BEGIN(option, base_loop) \
for (int __p_i = 0; __p_i < (base_loop); ++ __p_i) {
#define __RUNTIME_PROFILE_LOOP_END(result) }
#define __RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN(option, base_loop) \
for (int __p_i_h = 0; __p_i_h < (base_loop); ++ __p_i_h) {
#define __RUNTIME_PROFILE_LOOP_H2D_D2H_END(result) }
#endif


@@ -0,0 +1,47 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace fastdeploy {
/** \brief All C++ FastDeploy benchmark profile APIs are defined inside this namespace
*
*/
namespace benchmark {
/*! @brief Option object used to control the behavior of the benchmark profiling.
*/
struct BenchmarkOption {
int warmup = 50; ///< Warmup for backend inference.
int repeats = 100; ///< Repeats for backend inference.
bool enable_profile = false; ///< Whether to use profile or not.
bool include_h2d_d2h = false; ///< Whether to include time of H2D_D2H for time of runtime.
friend std::ostream& operator<<(
std::ostream& output, const BenchmarkOption &option) {
if (!option.include_h2d_d2h) {
output << "Running profiling for Runtime "
<< "without H2D and D2H, ";
} else {
output << "Running profiling for Runtime "
<< "with H2D and D2H, ";
}
output << "Repeats: " << option.repeats << ", "
<< "Warmup: " << option.warmup;
return output;
}
};
} // namespace benchmark
} // namespace fastdeploy


@@ -0,0 +1,27 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace fastdeploy {
namespace benchmark {
/*! @brief Result object used to record the time of runtime after benchmark profiling is done.
*/
struct BenchmarkResult {
///< Means pure_backend_time+time_of_h2d_d2h(if include_h2d_d2h=true).
double time_of_runtime = 0.0f;
};
} // namespace benchmark
} // namespace fastdeploy


@@ -56,3 +56,7 @@
 #ifndef ENABLE_TEXT
 #cmakedefine ENABLE_TEXT
 #endif
+
+#ifndef ENABLE_BENCHMARK
+#cmakedefine ENABLE_BENCHMARK
+#endif


@@ -31,7 +31,8 @@ std::string Str(const std::vector<Backend>& backends) {
   return oss.str();
 }
-bool IsSupported(const std::vector<Backend>& backends, Backend backend) {
+bool CheckBackendSupported(const std::vector<Backend>& backends,
+                           Backend backend) {
   for (size_t i = 0; i < backends.size(); ++i) {
     if (backends[i] == backend) {
       return true;
@@ -40,6 +41,22 @@ bool IsSupported(const std::vector<Backend>& backends, Backend backend) {
   return false;
 }
+bool FastDeployModel::IsSupported(const std::vector<Backend>& backends,
+                                  Backend backend) {
+#ifdef ENABLE_BENCHMARK
+  if (runtime_option.benchmark_option.enable_profile) {
+    FDWARNING << "In benchmark mode, we don't check to see if "
+              << "the backend [" << backend
+              << "] is supported for current model!"
+              << std::endl;
+    return true;
+  }
+  return CheckBackendSupported(backends, backend);
+#else
+  return CheckBackendSupported(backends, backend);
+#endif
+}
 bool FastDeployModel::InitRuntimeWithSpecifiedBackend() {
   if (!IsBackendAvailable(runtime_option.backend)) {
     FDERROR << runtime_option.backend
@@ -367,12 +384,13 @@ bool FastDeployModel::Infer(std::vector<FDTensor>& input_tensors,
     tc.End();
     if (time_of_runtime_.size() > 50000) {
       FDWARNING << "There are already 50000 records of runtime, will force to "
                    "disable record time of runtime now."
                 << std::endl;
       enable_record_time_of_runtime_ = false;
     }
     time_of_runtime_.push_back(tc.Duration());
   }
   return ret;
 }
@@ -416,6 +434,7 @@ std::map<std::string, float> FastDeployModel::PrintStatisInfoOfRuntime() {
statis_info_of_runtime_dict["warmup_iter"] = warmup_iter; statis_info_of_runtime_dict["warmup_iter"] = warmup_iter;
statis_info_of_runtime_dict["avg_time"] = avg_time; statis_info_of_runtime_dict["avg_time"] = avg_time;
statis_info_of_runtime_dict["iterations"] = time_of_runtime_.size(); statis_info_of_runtime_dict["iterations"] = time_of_runtime_.size();
return statis_info_of_runtime_dict; return statis_info_of_runtime_dict;
} }
} // namespace fastdeploy } // namespace fastdeploy


@@ -75,7 +75,7 @@ class FASTDEPLOY_DECL FastDeployModel {
     return runtime_initialized_ && initialized;
   }
-  /** \brief This is a debug interface, used to record the time of backend runtime
+  /** \brief This is a debug interface, used to record the time of runtime (backend + h2d + d2h)
   *
   * example code @code
   * auto model = fastdeploy::vision::PPYOLOE("model.pdmodel", "model.pdiparams", "infer_cfg.yml");
@@ -98,7 +98,7 @@ class FASTDEPLOY_DECL FastDeployModel {
     enable_record_time_of_runtime_ = true;
   }
-  /** \brief Disable to record the time of backend runtime, see `EnableRecordTimeOfRuntime()` for more detail
+  /** \brief Disable to record the time of runtime, see `EnableRecordTimeOfRuntime()` for more detail
   */
   virtual void DisableRecordTimeOfRuntime() {
     enable_record_time_of_runtime_ = false;
@@ -113,6 +113,11 @@ class FASTDEPLOY_DECL FastDeployModel {
   virtual bool EnabledRecordTimeOfRuntime() {
     return enable_record_time_of_runtime_;
   }
+  /** \brief Get profile time of Runtime after the profile process is done.
+   */
+  virtual double GetProfileTime() {
+    return runtime_->GetProfileTime();
+  }
   /** \brief Release reused input/output buffers
   */
@@ -153,13 +158,13 @@ class FASTDEPLOY_DECL FastDeployModel {
   bool CreateTimVXBackend();
   bool CreateKunlunXinBackend();
   bool CreateASCENDBackend();
+  bool IsSupported(const std::vector<Backend>& backends,
+                   Backend backend);
   std::shared_ptr<Runtime> runtime_;
   bool runtime_initialized_ = false;
   // whether to record inference time
   bool enable_record_time_of_runtime_ = false;
-  // record inference time for backend
   std::vector<double> time_of_runtime_;
 };


@@ -30,6 +30,8 @@ void BindFDModel(pybind11::module& m) {
            &FastDeployModel::DisableRecordTimeOfRuntime)
       .def("print_statis_info_of_runtime",
            &FastDeployModel::PrintStatisInfoOfRuntime)
+      .def("get_profile_time",
+           &FastDeployModel::GetProfileTime)
       .def("initialized", &FastDeployModel::Initialized)
       .def_readwrite("runtime_option", &FastDeployModel::runtime_option)
       .def_readwrite("valid_cpu_backends", &FastDeployModel::valid_cpu_backends)


@@ -77,6 +77,8 @@ void BindRuntime(pybind11::module& m) {
.def("set_ipu_config", &RuntimeOption::SetIpuConfig) .def("set_ipu_config", &RuntimeOption::SetIpuConfig)
.def("delete_paddle_backend_pass", .def("delete_paddle_backend_pass",
&RuntimeOption::DeletePaddleBackendPass) &RuntimeOption::DeletePaddleBackendPass)
.def("enable_profiling", &RuntimeOption::EnableProfiling)
.def("disable_profiling", &RuntimeOption::DisableProfiling)
.def("disable_paddle_trt_ops", &RuntimeOption::DisablePaddleTrtOPs) .def("disable_paddle_trt_ops", &RuntimeOption::DisablePaddleTrtOPs)
.def_readwrite("model_file", &RuntimeOption::model_file) .def_readwrite("model_file", &RuntimeOption::model_file)
.def_readwrite("params_file", &RuntimeOption::params_file) .def_readwrite("params_file", &RuntimeOption::params_file)
@@ -217,6 +219,7 @@ void BindRuntime(pybind11::module& m) {
.def("num_outputs", &Runtime::NumOutputs) .def("num_outputs", &Runtime::NumOutputs)
.def("get_input_info", &Runtime::GetInputInfo) .def("get_input_info", &Runtime::GetInputInfo)
.def("get_output_info", &Runtime::GetOutputInfo) .def("get_output_info", &Runtime::GetOutputInfo)
.def("get_profile_time", &Runtime::GetProfileTime)
.def_readonly("option", &Runtime::option); .def_readonly("option", &Runtime::option);
pybind11::enum_<Backend>(m, "Backend", pybind11::arithmetic(), pybind11::enum_<Backend>(m, "Backend", pybind11::arithmetic(),


@@ -22,6 +22,7 @@
#include "fastdeploy/core/fd_tensor.h" #include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/core/fd_type.h" #include "fastdeploy/core/fd_type.h"
#include "fastdeploy/runtime/runtime_option.h" #include "fastdeploy/runtime/runtime_option.h"
#include "fastdeploy/benchmark/benchmark.h"
namespace fastdeploy { namespace fastdeploy {
@@ -79,7 +80,6 @@ class BaseBackend {
   virtual bool Infer(std::vector<FDTensor>& inputs,
                      std::vector<FDTensor>* outputs,
                      bool copy_to_fd = true) = 0;
   // Optional: For those backends which can share memory
   // while creating multiple inference engines with same model file
   virtual std::unique_ptr<BaseBackend> Clone(RuntimeOption &runtime_option,
@@ -88,6 +88,70 @@ class BaseBackend {
FDERROR << "Clone no support" << std::endl; FDERROR << "Clone no support" << std::endl;
return nullptr; return nullptr;
} }
benchmark::BenchmarkOption benchmark_option_;
benchmark::BenchmarkResult benchmark_result_;
}; };
/** \brief Macros for Runtime benchmark profiling.
* The param 'base_loop' for 'RUNTIME_PROFILE_LOOP_BEGIN'
* indicates that the least number of times the loop
* will repeat when profiling mode is not enabled.
* In most cases, the value should be 1, i.e., results are
* obtained by running the inference process once, when
* the profile mode is turned off, such as ONNX Runtime,
* OpenVINO, TensorRT, Paddle Inference, Paddle Lite,
* RKNPU2, SOPHGO etc.
*
* example code @code
* // OpenVINOBackend::Infer
* RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN
* // do something ....
* RUNTIME_PROFILE_LOOP_BEGIN(1)
* // The codes which wrapped by 'BEGIN(1) ~ END' scope
* // will only run once when profiling mode is not enabled.
* request_.infer();
* RUNTIME_PROFILE_LOOP_END
* // do something ....
* RUNTIME_PROFILE_LOOP_H2D_D2H_END
*
* @endcode In this case, No global variables inside a function
* are wrapped by BEGIN and END, which may be required for
* subsequent tasks. But, some times we need to set 'base_loop'
* as 0, such as POROS.
*
* * example code @code
* // PorosBackend::Infer
* RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN
* // do something ....
* RUNTIME_PROFILE_LOOP_BEGIN(0) // set 'base_loop' as 0
* // The codes which wrapped by 'BEGIN(0) ~ END' scope
* // will not run when profiling mode is not enabled.
* auto poros_outputs = _poros_module->forward(poros_inputs);
* RUNTIME_PROFILE_LOOP_END
* // Run another inference beyond the scope of 'BEGIN ~ END'
* // to get valid outputs for subsequent tasks.
* auto poros_outputs = _poros_module->forward(poros_inputs);
* // do something .... will use 'poros_outputs' ...
* if (poros_outputs.isTensor()) {
* // ...
* }
* RUNTIME_PROFILE_LOOP_H2D_D2H_END
*
* @endcode In this case, 'poros_outputs' inside a function
* are wrapped by BEGIN and END, which may be required for
* subsequent tasks. So, we set 'base_loop' as 0 and lanuch
* another infer to get the valid outputs beyond the scope
* of 'BEGIN ~ END' for subsequent tasks.
*/
#define RUNTIME_PROFILE_LOOP_BEGIN(base_loop) \
__RUNTIME_PROFILE_LOOP_BEGIN(benchmark_option_, (base_loop))
#define RUNTIME_PROFILE_LOOP_END \
__RUNTIME_PROFILE_LOOP_END(benchmark_result_)
#define RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN \
__RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN(benchmark_option_, 1)
#define RUNTIME_PROFILE_LOOP_H2D_D2H_END \
__RUNTIME_PROFILE_LOOP_H2D_D2H_END(benchmark_result_)
 }  // namespace fastdeploy


@@ -13,23 +13,6 @@
 // limitations under the License.
 #include "fastdeploy/runtime/backends/lite/lite_backend.h"
-// https://github.com/PaddlePaddle/Paddle-Lite/issues/8290
-// When compiling the FastDeploy dynamic library, namely,
-// WITH_STATIC_LIB=OFF, and depending on the Paddle Lite
-// static library, you need to include the fake registration
-// codes of Paddle Lite. When you compile the FastDeploy static
-// library and depends on the Paddle Lite static library,
-// WITH_STATIC_LIB=ON, you do not need to include the fake
-// registration codes for Paddle Lite, but wait until you
-// use the FastDeploy static library.
-#if (defined(WITH_LITE_STATIC) && (!defined(WITH_STATIC_LIB)))
-#warning You are compiling the FastDeploy dynamic library with \
-Paddle Lite static lib We will automatically add some registration \
-codes for ops, kernels and passes for Paddle Lite.
-#include "paddle_use_kernels.h"  // NOLINT
-#include "paddle_use_ops.h"      // NOLINT
-#include "paddle_use_passes.h"   // NOLINT
-#endif
 #include <cstring>
@@ -156,4 +139,5 @@ void LiteBackend::ConfigureNNAdapter(const LiteBackendOption& option) {
   config_.set_nnadapter_dynamic_shape_info(option.nnadapter_dynamic_shape_info);
 }
 }  // namespace fastdeploy


@@ -100,7 +100,7 @@ bool LiteBackend::InitFromPaddle(const std::string& model_file,
       auto shape = tensor->shape();
       info.shape.assign(shape.begin(), shape.end());
       info.name = output_names[i];
-      if (!option_.device == Device::KUNLUNXIN) {
+      if (option_.device != Device::KUNLUNXIN) {
         info.dtype = LiteDataTypeToFD(tensor->precision());
       }
       outputs_desc_.emplace_back(info);
@@ -136,6 +136,8 @@ bool LiteBackend::Infer(std::vector<FDTensor>& inputs,
             << inputs_desc_.size() << ")." << std::endl;
     return false;
   }
+  RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN
   for (size_t i = 0; i < inputs.size(); ++i) {
     auto iter = inputs_order_.find(inputs[i].name);
     if (iter == inputs_order_.end()) {
@@ -143,6 +145,7 @@ bool LiteBackend::Infer(std::vector<FDTensor>& inputs,
<< " in loaded model." << std::endl; << " in loaded model." << std::endl;
return false; return false;
} }
auto tensor = predictor_->GetInput(iter->second); auto tensor = predictor_->GetInput(iter->second);
// Adjust dims only, allocate lazy. // Adjust dims only, allocate lazy.
tensor->Resize(inputs[i].shape); tensor->Resize(inputs[i].shape);
@@ -175,7 +178,9 @@ bool LiteBackend::Infer(std::vector<FDTensor>& inputs,
     }
   }
+  RUNTIME_PROFILE_LOOP_BEGIN(1)
   predictor_->Run();
+  RUNTIME_PROFILE_LOOP_END
   outputs->resize(outputs_desc_.size());
   for (size_t i = 0; i < outputs_desc_.size(); ++i) {
@@ -188,6 +193,7 @@ bool LiteBackend::Infer(std::vector<FDTensor>& inputs,
       memcpy((*outputs)[i].MutableData(), tensor->data<void>(),
              (*outputs)[i].Nbytes());
     }
+  RUNTIME_PROFILE_LOOP_H2D_D2H_END
   return true;
 }

0
fastdeploy/runtime/backends/lite/lite_backend.h Executable file → Normal file


@@ -375,6 +375,7 @@ bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs,
     return false;
   }
+  RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN
   for (size_t i = 0; i < inputs.size(); ++i) {
     ov::Shape shape(inputs[i].shape.begin(), inputs[i].shape.end());
     ov::Tensor ov_tensor(FDDataTypeToOV(inputs[i].dtype), shape,
@@ -382,7 +383,9 @@ bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs,
     request_.set_tensor(inputs[i].name, ov_tensor);
   }
+  RUNTIME_PROFILE_LOOP_BEGIN(1)
   request_.infer();
+  RUNTIME_PROFILE_LOOP_END
   outputs->resize(output_infos_.size());
   for (size_t i = 0; i < output_infos_.size(); ++i) {
@@ -403,6 +406,7 @@ bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs,
                              out_tensor.data(), Device::CPU);
     }
   }
+  RUNTIME_PROFILE_LOOP_H2D_D2H_END
   return true;
 }


@@ -13,9 +13,6 @@
 // limitations under the License.
 #include "fastdeploy/runtime/backends/ort/ort_backend.h"
-#include <memory>
 #include "fastdeploy/core/float16.h"
 #include "fastdeploy/runtime/backends/ort/ops/adaptive_pool2d.h"
 #include "fastdeploy/runtime/backends/ort/ops/multiclass_nms.h"
@@ -25,6 +22,9 @@
#include "paddle2onnx/converter.h" #include "paddle2onnx/converter.h"
#endif #endif
#include <memory>
namespace fastdeploy { namespace fastdeploy {
std::vector<OrtCustomOp*> OrtBackend::custom_operators_ = std::vector<OrtCustomOp*> OrtBackend::custom_operators_ =
@@ -258,6 +258,7 @@ bool OrtBackend::Infer(std::vector<FDTensor>& inputs,
   }
   // from FDTensor to Ort Inputs
+  RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN
   for (size_t i = 0; i < inputs.size(); ++i) {
     auto ort_value = CreateOrtValue(inputs[i], option_.device == Device::GPU);
     binding_->BindInput(inputs[i].name.c_str(), ort_value);
@@ -270,12 +271,14 @@ bool OrtBackend::Infer(std::vector<FDTensor>& inputs,
   }
   // Inference with inputs
+  RUNTIME_PROFILE_LOOP_BEGIN(1)
   try {
     session_.Run({}, *(binding_.get()));
   } catch (const std::exception& e) {
     FDERROR << "Failed to Infer: " << e.what() << std::endl;
     return false;
   }
+  RUNTIME_PROFILE_LOOP_END
   // Convert result after inference
   std::vector<Ort::Value> ort_outputs = binding_->GetOutputValues();
@@ -284,7 +287,7 @@ bool OrtBackend::Infer(std::vector<FDTensor>& inputs,
     OrtValueToFDTensor(ort_outputs[i], &((*outputs)[i]), outputs_desc_[i].name,
                        copy_to_fd);
   }
+  RUNTIME_PROFILE_LOOP_H2D_D2H_END
   return true;
 }


@@ -222,12 +222,15 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
     return false;
   }
+  RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN
   for (size_t i = 0; i < inputs.size(); ++i) {
     auto handle = predictor_->GetInputHandle(inputs[i].name);
     ShareTensorFromFDTensor(handle.get(), inputs[i]);
   }
+  RUNTIME_PROFILE_LOOP_BEGIN(1)
   predictor_->Run();
+  RUNTIME_PROFILE_LOOP_END
   // output share backend memory only support CPU or GPU
   if (option_.use_ipu) {
@@ -241,6 +244,7 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
     }
     PaddleTensorToFDTensor(handle, &((*outputs)[i]), copy_to_fd);
   }
+  RUNTIME_PROFILE_LOOP_H2D_D2H_END
   return true;
 }


@@ -287,14 +287,18 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
     BuildTrtEngine();
   }
+  RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN
   cudaSetDevice(option_.gpu_id);
   SetInputs(inputs);
   AllocateOutputsBuffer(outputs, copy_to_fd);
+  RUNTIME_PROFILE_LOOP_BEGIN(1)
   if (!context_->enqueueV2(bindings_.data(), stream_, nullptr)) {
     FDERROR << "Failed to Infer with TensorRT." << std::endl;
     return false;
   }
+  RUNTIME_PROFILE_LOOP_END
   for (size_t i = 0; i < outputs->size(); ++i) {
     // if the final output tensor's dtype is different from the model output
     // tensor's dtype, then we need cast the data to the final output's dtype
@@ -335,7 +339,7 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
     FDASSERT(cudaStreamSynchronize(stream_) == cudaSuccess,
              "[ERROR] Error occurs while sync cuda stream.");
   }
+  RUNTIME_PROFILE_LOOP_H2D_D2H_END
   return true;
 }


@@ -275,6 +275,8 @@ void Runtime::CreatePaddleBackend() {
 #endif
   backend_ = utils::make_unique<PaddleBackend>();
   auto casted_backend = dynamic_cast<PaddleBackend*>(backend_.get());
+  casted_backend->benchmark_option_ = option.benchmark_option;
   if (pd_option.model_from_memory_) {
     FDASSERT(casted_backend->InitFromPaddle(option.model_file,
                                             option.params_file, pd_option),
@@ -303,6 +305,7 @@ void Runtime::CreatePaddleBackend() {
 void Runtime::CreateOpenVINOBackend() {
 #ifdef ENABLE_OPENVINO_BACKEND
   backend_ = utils::make_unique<OpenVINOBackend>();
+  backend_->benchmark_option_ = option.benchmark_option;
   FDASSERT(backend_->Init(option), "Failed to initialize OpenVINOBackend.");
 #else
   FDASSERT(false,
@@ -316,6 +319,8 @@ void Runtime::CreateOpenVINOBackend() {
 void Runtime::CreateOrtBackend() {
 #ifdef ENABLE_ORT_BACKEND
   backend_ = utils::make_unique<OrtBackend>();
+  backend_->benchmark_option_ = option.benchmark_option;
   FDASSERT(backend_->Init(option), "Failed to initialize Backend::ORT.");
 #else
   FDASSERT(false,
@@ -351,6 +356,8 @@ void Runtime::CreateTrtBackend() {
   trt_option.external_stream_ = option.external_stream_;
   backend_ = utils::make_unique<TrtBackend>();
   auto casted_backend = dynamic_cast<TrtBackend*>(backend_.get());
+  casted_backend->benchmark_option_ = option.benchmark_option;
   if (option.model_format == ModelFormat::ONNX) {
     if (option.model_from_memory_) {
       FDASSERT(casted_backend->InitFromOnnx(option.model_file, trt_option),
@@ -403,6 +410,8 @@ void Runtime::CreateLiteBackend() {
"LiteBackend only support model format of ModelFormat::PADDLE"); "LiteBackend only support model format of ModelFormat::PADDLE");
backend_ = utils::make_unique<LiteBackend>(); backend_ = utils::make_unique<LiteBackend>();
auto casted_backend = dynamic_cast<LiteBackend*>(backend_.get()); auto casted_backend = dynamic_cast<LiteBackend*>(backend_.get());
casted_backend->benchmark_option_ = option.benchmark_option;
FDASSERT(casted_backend->InitFromPaddle(option.model_file, option.params_file, FDASSERT(casted_backend->InitFromPaddle(option.model_file, option.params_file,
option.paddle_lite_option), option.paddle_lite_option),
"Load model from nb file failed while initializing LiteBackend."); "Load model from nb file failed while initializing LiteBackend.");


@@ -95,6 +95,11 @@ struct FASTDEPLOY_DECL Runtime {
   */
   bool Compile(std::vector<std::vector<FDTensor>>& prewarm_tensors,
                const RuntimeOption& _option);
+  /** \brief Get profile time of Runtime after the profile process is done.
+   */
+  double GetProfileTime() {
+    return backend_->benchmark_result_.time_of_runtime;
+  }
  private:
   void CreateOrtBackend();


@@ -32,6 +32,7 @@
#include "fastdeploy/runtime/backends/rknpu2/option.h" #include "fastdeploy/runtime/backends/rknpu2/option.h"
#include "fastdeploy/runtime/backends/sophgo/option.h" #include "fastdeploy/runtime/backends/sophgo/option.h"
#include "fastdeploy/runtime/backends/tensorrt/option.h" #include "fastdeploy/runtime/backends/tensorrt/option.h"
#include "fastdeploy/benchmark/option.h"
namespace fastdeploy { namespace fastdeploy {
@@ -347,6 +348,26 @@ struct FASTDEPLOY_DECL RuntimeOption {
                     float available_memory_proportion = 1.0,
                     bool enable_half_partial = false);
+  /** \brief Set the profile mode as 'true'.
+   *
+   * \param[in] inclue_h2d_d2h Whether to include time of H2D_D2H for time of runtime.
+   * \param[in] repeat Repeat times for runtime inference.
+   * \param[in] warmup Warmup times for runtime inference.
+   */
+  void EnableProfiling(bool inclue_h2d_d2h = false,
+                       int repeat = 100, int warmup = 50) {
+    benchmark_option.enable_profile = true;
+    benchmark_option.warmup = warmup;
+    benchmark_option.repeats = repeat;
+    benchmark_option.include_h2d_d2h = inclue_h2d_d2h;
+  }
+
+  /** \brief Set the profile mode as 'false'.
+   */
+  void DisableProfiling() {
+    benchmark_option.enable_profile = false;
+  }
+
   Backend backend = Backend::UNKNOWN;
   // for cpu inference
@@ -419,6 +440,9 @@ struct FASTDEPLOY_DECL RuntimeOption {
   bool model_from_memory_ = false;
   // format of input model
   ModelFormat model_format = ModelFormat::PADDLE;
+  // Benchmark option
+  benchmark::BenchmarkOption benchmark_option;
 };
 }  // namespace fastdeploy


@@ -54,6 +54,11 @@ class FastDeployModel:
     def print_statis_info_of_runtime(self):
         return self._model.print_statis_info_of_runtime()
+    def get_profile_time(self):
+        """Get profile time of Runtime after the profile process is done.
+        """
+        return self._model.get_profile_time()
     @property
     def runtime_option(self):
         return self._model.runtime_option if self._model is not None else None


@@ -144,6 +144,11 @@ class Runtime:
                 index, self.num_outputs)
         return self._runtime.get_output_info(index)
+    def get_profile_time(self):
+        """Get profile time of Runtime after the profile process is done.
+        """
+        return self._runtime.get_profile_time()
 class RuntimeOption:
     """Options for FastDeploy Runtime.
@@ -552,6 +557,21 @@ class RuntimeOption:
             available_memory_proportion,
             enable_half_partial)
+    def enable_profiling(self,
+                         inclue_h2d_d2h=False,
+                         repeat=100, warmup=50):
+        """Set the profile mode as 'true'.
+        :param inclue_h2d_d2h Whether to include time of H2D_D2H for time of runtime.
+        :param repeat Repeat times for runtime inference.
+        :param warmup Warmup times for runtime inference.
+        """
+        return self._option.enable_profiling(inclue_h2d_d2h, repeat, warmup)
+
+    def disable_profiling(self):
+        """Set the profile mode as 'false'.
+        """
+        return self._option.disable_profiling()
     def __repr__(self):
         attrs = dir(self._option)
         message = "RuntimeOption(\n"


@@ -73,6 +73,7 @@ setup_configs["ENABLE_VISION"] = os.getenv("ENABLE_VISION", "OFF")
setup_configs["ENABLE_ENCRYPTION"] = os.getenv("ENABLE_ENCRYPTION", "OFF") setup_configs["ENABLE_ENCRYPTION"] = os.getenv("ENABLE_ENCRYPTION", "OFF")
setup_configs["ENABLE_FLYCV"] = os.getenv("ENABLE_FLYCV", "OFF") setup_configs["ENABLE_FLYCV"] = os.getenv("ENABLE_FLYCV", "OFF")
setup_configs["ENABLE_TEXT"] = os.getenv("ENABLE_TEXT", "OFF") setup_configs["ENABLE_TEXT"] = os.getenv("ENABLE_TEXT", "OFF")
setup_configs["ENABLE_BENCHMARK"] = os.getenv("ENABLE_BENCHMARK", "OFF")
setup_configs["WITH_GPU"] = os.getenv("WITH_GPU", "OFF") setup_configs["WITH_GPU"] = os.getenv("WITH_GPU", "OFF")
setup_configs["WITH_IPU"] = os.getenv("WITH_IPU", "OFF") setup_configs["WITH_IPU"] = os.getenv("WITH_IPU", "OFF")
setup_configs["WITH_KUNLUNXIN"] = os.getenv("WITH_KUNLUNXIN", "OFF") setup_configs["WITH_KUNLUNXIN"] = os.getenv("WITH_KUNLUNXIN", "OFF")


@@ -0,0 +1,80 @@
#!/bin/bash
set -e
set +x
# -------------------------------------------------------------------------------
# readonly global variables
# -------------------------------------------------------------------------------
readonly ROOT_PATH=$(pwd)
readonly BUILD_ROOT=build/Linux
readonly BUILD_DIR=${BUILD_ROOT}/x86_64
# -------------------------------------------------------------------------------
# tasks
# -------------------------------------------------------------------------------
__make_build_dir() {
if [ ! -d "${BUILD_DIR}" ]; then
echo "-- [INFO] BUILD_DIR: ${BUILD_DIR} not exists, setup manually ..."
if [ ! -d "${BUILD_ROOT}" ]; then
mkdir -p "${BUILD_ROOT}" && echo "-- [INFO] Created ${BUILD_ROOT} !"
fi
mkdir -p "${BUILD_DIR}" && echo "-- [INFO] Created ${BUILD_DIR} !"
else
echo "-- [INFO] Found BUILD_DIR: ${BUILD_DIR}"
fi
}
__check_cxx_envs() {
if [ $LDFLAGS ]; then
echo "-- [INFO] Found LDFLAGS: ${LDFLAGS}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset LDFLAGS
fi
if [ $CPPFLAGS ]; then
echo "-- [INFO] Found CPPFLAGS: ${CPPFLAGS}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset CPPFLAGS
fi
if [ $CPLUS_INCLUDE_PATH ]; then
echo "-- [INFO] Found CPLUS_INCLUDE_PATH: ${CPLUS_INCLUDE_PATH}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset CPLUS_INCLUDE_PATH
fi
if [ $C_INCLUDE_PATH ]; then
echo "-- [INFO] Found C_INCLUDE_PATH: ${C_INCLUDE_PATH}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset C_INCLUDE_PATH
fi
}
__build_fastdeploy_linux_x86_64_shared() {
local FASDEPLOY_INSTALL_DIR="${ROOT_PATH}/${BUILD_DIR}/install"
cd "${BUILD_DIR}" && echo "-- [INFO] Working Dir: ${PWD}"
cmake -DCMAKE_BUILD_TYPE=Release \
-DWITH_GPU=OFF \
-DENABLE_ORT_BACKEND=ON \
-DENABLE_PADDLE_BACKEND=ON \
-DENABLE_OPENVINO_BACKEND=ON \
-DENABLE_PADDLE2ONNX=ON \
-DENABLE_VISION=ON \
-DENABLE_BENCHMARK=ON \
-DBUILD_EXAMPLES=ON \
-DCMAKE_INSTALL_PREFIX=${FASDEPLOY_INSTALL_DIR} \
-Wno-dev ../../.. && make -j8 && make install
echo "-- [INFO][built][x86_64]][${BUILD_DIR}/install]"
}
main() {
__make_build_dir
__check_cxx_envs
__build_fastdeploy_linux_x86_64_shared
exit 0
}
main
# Usage:
# ./scripts/linux/build_linux_x86_64_cpp_cpu.sh


@@ -0,0 +1,83 @@
#!/bin/bash
set -e
set +x
# -------------------------------------------------------------------------------
# readonly global variables
# -------------------------------------------------------------------------------
readonly ROOT_PATH=$(pwd)
readonly BUILD_ROOT=build/Linux
readonly BUILD_DIR="${BUILD_ROOT}/x86_64_gpu"
# -------------------------------------------------------------------------------
# tasks
# -------------------------------------------------------------------------------
__make_build_dir() {
if [ ! -d "${BUILD_DIR}" ]; then
echo "-- [INFO] BUILD_DIR: ${BUILD_DIR} not exists, setup manually ..."
if [ ! -d "${BUILD_ROOT}" ]; then
mkdir -p "${BUILD_ROOT}" && echo "-- [INFO] Created ${BUILD_ROOT} !"
fi
mkdir -p "${BUILD_DIR}" && echo "-- [INFO] Created ${BUILD_DIR} !"
else
echo "-- [INFO] Found BUILD_DIR: ${BUILD_DIR}"
fi
}
__check_cxx_envs() {
if [ $LDFLAGS ]; then
echo "-- [INFO] Found LDFLAGS: ${LDFLAGS}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset LDFLAGS
fi
if [ $CPPFLAGS ]; then
echo "-- [INFO] Found CPPFLAGS: ${CPPFLAGS}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset CPPFLAGS
fi
if [ $CPLUS_INCLUDE_PATH ]; then
echo "-- [INFO] Found CPLUS_INCLUDE_PATH: ${CPLUS_INCLUDE_PATH}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset CPLUS_INCLUDE_PATH
fi
if [ $C_INCLUDE_PATH ]; then
echo "-- [INFO] Found C_INCLUDE_PATH: ${C_INCLUDE_PATH}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset C_INCLUDE_PATH
fi
}
__build_fastdeploy_linux_x86_64_gpu_shared() {
local FASDEPLOY_INSTALL_DIR="${ROOT_PATH}/${BUILD_DIR}/install"
cd "${BUILD_DIR}" && echo "-- [INFO] Working Dir: ${PWD}"
cmake -DCMAKE_BUILD_TYPE=Release \
-DWITH_GPU=ON \
-DTRT_DIRECTORY=${TRT_DIRECTORY} \
-DCUDA_DIRECTORY=${CUDA_DIRECTORY} \
-DENABLE_ORT_BACKEND=ON \
-DENABLE_TRT_BACKEND=ON \
-DENABLE_PADDLE_BACKEND=ON \
-DENABLE_OPENVINO_BACKEND=ON \
-DENABLE_PADDLE2ONNX=ON \
-DENABLE_VISION=ON \
-DENABLE_BENCHMARK=ON \
-DBUILD_EXAMPLES=ON \
-DCMAKE_INSTALL_PREFIX=${FASDEPLOY_INSTALL_DIR} \
-Wno-dev ../../.. && make -j8 && make install
echo "-- [INFO][built][x86_64_gpu}][${BUILD_DIR}/install]"
}
main() {
__make_build_dir
__check_cxx_envs
__build_fastdeploy_linux_x86_64_gpu_shared
exit 0
}
main
# Usage:
# ./scripts/linux/build_linux_x86_64_cpp_gpu.sh


@@ -0,0 +1,102 @@
#!/bin/bash
set -e
set +x
# -------------------------------------------------------------------------------
# readonly global variables
# -------------------------------------------------------------------------------
readonly ROOT_PATH=$(pwd)
readonly BUILD_ROOT=build/MacOSX
readonly OSX_ARCH=$1 # arm64, x86_64
readonly BUILD_DIR=${BUILD_ROOT}/${OSX_ARCH}
# -------------------------------------------------------------------------------
# tasks
# -------------------------------------------------------------------------------
__make_build_dir() {
if [ ! -d "${BUILD_DIR}" ]; then
echo "-- [INFO] BUILD_DIR: ${BUILD_DIR} not exists, setup manually ..."
if [ ! -d "${BUILD_ROOT}" ]; then
mkdir -p "${BUILD_ROOT}" && echo "-- [INFO] Created ${BUILD_ROOT} !"
fi
mkdir -p "${BUILD_DIR}" && echo "-- [INFO] Created ${BUILD_DIR} !"
else
echo "-- [INFO] Found BUILD_DIR: ${BUILD_DIR}"
fi
}
__check_cxx_envs() {
if [ $LDFLAGS ]; then
echo "-- [INFO] Found LDFLAGS: ${LDFLAGS}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset LDFLAGS
fi
if [ $CPPFLAGS ]; then
echo "-- [INFO] Found CPPFLAGS: ${CPPFLAGS}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset CPPFLAGS
fi
if [ $CPLUS_INCLUDE_PATH ]; then
echo "-- [INFO] Found CPLUS_INCLUDE_PATH: ${CPLUS_INCLUDE_PATH}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset CPLUS_INCLUDE_PATH
fi
if [ $C_INCLUDE_PATH ]; then
echo "-- [INFO] Found C_INCLUDE_PATH: ${C_INCLUDE_PATH}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset C_INCLUDE_PATH
fi
}
__build_fastdeploy_osx_arm64_shared() {
local FASDEPLOY_INSTALL_DIR="${ROOT_PATH}/${BUILD_DIR}/install"
cd "${BUILD_DIR}" && echo "-- [INFO] Working Dir: ${PWD}"
cmake -DCMAKE_BUILD_TYPE=MinSizeRel \
-DENABLE_ORT_BACKEND=ON \
-DENABLE_PADDLE2ONNX=ON \
-DENABLE_VISION=ON \
-DENABLE_BENCHMARK=ON \
-DBUILD_EXAMPLES=ON \
-DCMAKE_INSTALL_PREFIX=${FASDEPLOY_INSTALL_DIR} \
-Wno-dev ../../.. && make -j8 && make install
echo "-- [INFO][built][${OSX_ARCH}][${BUILD_DIR}/install]"
}
__build_fastdeploy_osx_x86_64_shared() {
local FASDEPLOY_INSTALL_DIR="${ROOT_PATH}/${BUILD_DIR}/install"
cd "${BUILD_DIR}" && echo "-- [INFO] Working Dir: ${PWD}"
cmake -DCMAKE_BUILD_TYPE=MinSizeRel \
-DENABLE_ORT_BACKEND=ON \
-DENABLE_PADDLE_BACKEND=ON \
-DENABLE_OPENVINO_BACKEND=ON \
-DENABLE_PADDLE2ONNX=ON \
-DENABLE_VISION=ON \
-DENABLE_BENCHMARK=ON \
-DBUILD_EXAMPLES=ON \
-DCMAKE_INSTALL_PREFIX=${FASDEPLOY_INSTALL_DIR} \
-Wno-dev ../../.. && make -j8 && make install
echo "-- [INFO][built][${OSX_ARCH}][${BUILD_DIR}/install]"
}
main() {
__make_build_dir
__check_cxx_envs
if [ "$OSX_ARCH" = "arm64" ]; then
__build_fastdeploy_osx_arm64_shared
else
__build_fastdeploy_osx_x86_64_shared
fi
exit 0
}
main
# Usage:
# ./scripts/macosx/build_macosx_cpp.sh arm64
# ./scripts/macosx/build_macosx_cpp.sh x86_64