From f73a538f619e23725af4898e3354858e1b2db15b Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Mon, 6 Feb 2023 14:29:35 +0800 Subject: [PATCH] [Backend] support benchmark mode for runtime and backend (#1201) * [backend] support benchmark mode for runtime and backend * [backend] support benchmark mode for runtime and backend * [pybind11] add benchmark methods pybind * [pybind11] add benchmark methods pybind * [Other] Update build scripts * [Other] Update cmake/summary.cmake * [Other] update build scripts * [Other] add ENABLE_BENCHMARK option -> setup.py * optimize backend time recording * optimize backend time recording * optimize trt backend time record * [backend] optimize backend_time recording for trt * [benchmark] remove redundant logs * fixed ov_backend conflict * [benchmark] fixed paddle_backend conflicts * [benchmark] fixed paddle_backend conflicts * [benchmark] fixed paddle_backend conflicts * [benchmark] remove use_gpu option from ort backend option * [benchmark] update benchmark_ppdet.py * [benchmark] update benchmark_ppcls.py * fixed lite backend conflicts * [Lite] fixed lite xpu * add benchmark macro * add RUNTIME_PROFILE_LOOP macros * add comments for RUNTIME_PROFILE macros * add comments for new apis * add comments for new apis * update benchmark_ppdet.py * fixed bugs * remove unused codes * optimize RUNTIME_PROFILE_LOOP macros * optimize RUNTIME_PROFILE_LOOP macros * add comments for benchmark option and result * add docs for benchmark namespace --- CMakeLists.txt | 1 + benchmark/.gitignore | 13 +++ benchmark/benchmark_ppcls.py | 67 +++++++----- benchmark/benchmark_ppdet.py | 75 ++++++++----- cmake/summary.cmake | 1 + fastdeploy/benchmark/benchmark.h | 86 +++++++++++++++ fastdeploy/benchmark/option.h | 47 ++++++++ fastdeploy/benchmark/results.h | 27 +++++ fastdeploy/core/config.h.in | 4 + fastdeploy/fastdeploy_model.cc | 23 +++- fastdeploy/fastdeploy_model.h | 13 ++- fastdeploy/pybind/fastdeploy_model.cc | 2 + fastdeploy/pybind/runtime.cc | 3 + fastdeploy/runtime/backends/backend.h | 66 +++++++++++- .../backends/lite/configure_hardware.cc | 18 +--- .../runtime/backends/lite/lite_backend.cc | 10 +- .../runtime/backends/lite/lite_backend.h | 0 fastdeploy/runtime/backends/lite/option.h | 2 +- .../runtime/backends/openvino/ov_backend.cc | 6 +- .../runtime/backends/openvino/ov_backend.h | 4 +- .../runtime/backends/ort/ort_backend.cc | 11 +- fastdeploy/runtime/backends/ort/ort_backend.h | 2 +- .../runtime/backends/paddle/paddle_backend.cc | 12 ++- .../runtime/backends/paddle/paddle_backend.h | 2 +- .../runtime/backends/tensorrt/trt_backend.cc | 8 +- fastdeploy/runtime/runtime.cc | 9 ++ fastdeploy/runtime/runtime.h | 5 + fastdeploy/runtime/runtime_option.h | 24 +++++ python/fastdeploy/model.py | 5 + python/fastdeploy/runtime.py | 20 ++++ python/setup.py | 1 + scripts/linux/build_linux_x86_64_cpp_cpu.sh | 80 ++++++++++++++ scripts/linux/build_linux_x86_64_cpp_gpu.sh | 83 ++++++++++++++ scripts/macosx/build_macosx_cpp.sh | 102 ++++++++++++++++++ 34 files changed, 741 insertions(+), 91 deletions(-) create mode 100644 benchmark/.gitignore create mode 100644 fastdeploy/benchmark/benchmark.h create mode 100644 fastdeploy/benchmark/option.h create mode 100644 fastdeploy/benchmark/results.h mode change 100755 => 100644 fastdeploy/runtime/backends/lite/lite_backend.h create mode 100755 scripts/linux/build_linux_x86_64_cpp_cpu.sh create mode 100755 scripts/linux/build_linux_x86_64_cpp_gpu.sh create mode 100755 scripts/macosx/build_macosx_cpp.sh
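Before the per-file hunks, a quick sketch of the workflow this patch enables. This is a minimal, hypothetical example (placeholder model paths, CPU backend, input preparation elided) and assumes a build with -DENABLE_BENCHMARK=ON:

#include <iostream>
#include <vector>
#include "fastdeploy/runtime/runtime.h"

int main() {
  fastdeploy::RuntimeOption option;
  option.SetModelPath("model.pdmodel", "model.pdiparams");  // placeholders
  option.UseCpu();
  // Profile pure backend time (exclude H2D/D2H): 100 repeats after 50 warmups.
  option.EnableProfiling(/*include_h2d_d2h=*/false, /*repeat=*/100, /*warmup=*/50);

  fastdeploy::Runtime runtime;
  if (!runtime.Init(option)) return -1;

  std::vector<fastdeploy::FDTensor> inputs(runtime.NumInputs());
  // ... fill `inputs` with data matching the model's input shapes ...
  std::vector<fastdeploy::FDTensor> outputs;
  // With profiling enabled, this single call loops warmup + repeats times
  // inside the backend via the RUNTIME_PROFILE_LOOP macros added below.
  runtime.Infer(inputs, &outputs);

  std::cout << "Runtime(ms): " << runtime.GetProfileTime() * 1000.0 << std::endl;
  return 0;
}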
diff --git a/CMakeLists.txt b/CMakeLists.txt index 44a71c2fb..2b4d5223c 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -68,6 +68,7 @@ option(ENABLE_TEXT "Whether to enable text models usage." OFF) option(ENABLE_FLYCV "Whether to enable flycv to boost image preprocess." OFF) option(ENABLE_CVCUDA "Whether to enable NVIDIA CV-CUDA to boost image preprocess." OFF) option(ENABLE_ENCRYPTION "Whether to enable ENCRYPTION." OFF) +option(ENABLE_BENCHMARK "Whether to enable Benchmark mode." OFF) option(WITH_ASCEND "Whether to compile for Huawei Ascend deploy." OFF) option(WITH_TIMVX "Whether to compile for TIMVX deploy." OFF) option(WITH_KUNLUNXIN "Whether to compile for KunlunXin XPU deploy." OFF) diff --git a/benchmark/.gitignore b/benchmark/.gitignore new file mode 100644 index 000000000..89096b6cd --- /dev/null +++ b/benchmark/.gitignore @@ -0,0 +1,13 @@ +*.tgz +*.zip +*.tar +*.tar.gz +*.tgz +*.jpg +*.png +*.jpeg +*.txt +*.log +yolov8_s_* +._yolov8_s_* +Mobile* \ No newline at end of file diff --git a/benchmark/benchmark_ppcls.py b/benchmark/benchmark_ppcls.py index 6b88658ee..a8219b028 100755 --- a/benchmark/benchmark_ppcls.py +++ b/benchmark/benchmark_ppcls.py @@ -17,7 +17,7 @@ import cv2 import os import numpy as np import time - +from tqdm import tqdm def parse_arguments(): import argparse @@ -35,11 +35,22 @@ def parse_arguments(): parser.add_argument( "--device_id", type=int, default=0, help="device(gpu) id") parser.add_argument( - "--iter_num", + "--profile_mode", + type=str, + default="runtime", + help="runtime or end2end.") + parser.add_argument( + "--repeat", required=True, type=int, - default=300, - help="number of iterations for computing performace.") + default=1000, + help="number of repeats for profiling.") + parser.add_argument( + "--warmup", + required=True, + type=int, + default=50, + help="number of warmup for profiling.") parser.add_argument( "--device", default="cpu", @@ -59,6 +70,11 @@ def parse_arguments(): type=ast.literal_eval, default=False, help="whether enable collect memory info") + parser.add_argument( + "--include_h2d_d2h", + type=ast.literal_eval, + default=False, + help="whether to run profiling with h2d and d2h") args = parser.parse_args() return args @@ -68,6 +84,8 @@ def build_option(args): device = args.device backend = args.backend enable_trt_fp16 = args.enable_trt_fp16 + if args.profile_mode == "runtime": + option.enable_profiling(args.include_h2d_d2h, args.repeat, args.warmup) option.set_cpu_thread_num(args.cpu_num_thread) if device == "gpu": option.use_gpu() @@ -229,7 +247,6 @@ if __name__ == '__main__': gpu_id = args.device_id enable_collect_memory_info = args.enable_collect_memory_info dump_result = dict() - end2end_statis = list() cpu_mem = list() gpu_mem = list() gpu_util = list() @@ -257,19 +274,27 @@ if __name__ == '__main__': enable_gpu = args.device == "gpu" monitor = Monitor(enable_gpu, gpu_id) monitor.start() - - model.enable_record_time_of_runtime() + im_ori = cv2.imread(args.image) - for i in range(args.iter_num): - im = im_ori + if args.profile_mode == "runtime": + result = model.predict(im_ori) + profile_time = model.get_profile_time() + dump_result["runtime"] = profile_time * 1000 + f.writelines("Runtime(ms): {} \n".format(str(dump_result["runtime"]))) + print("Runtime(ms): {} \n".format(str(dump_result["runtime"]))) + else: + # end2end + for i in range(args.warmup): + result = model.predict(im_ori) + start = time.time() - result = model.predict(im) - end2end_statis.append(time.time() - start) + for i in tqdm(range(args.repeat)): + result =
model.predict(im_ori) + end = time.time() + dump_result["end2end"] = ((end - start) / args.repeat) * 1000.0 + f.writelines("End2End(ms): {} \n".format(str(dump_result["end2end"]))) + print("End2End(ms): {} \n".format(str(dump_result["end2end"]))) - runtime_statis = model.print_statis_info_of_runtime() - - warmup_iter = args.iter_num // 5 - end2end_statis_repeat = end2end_statis[warmup_iter:] if enable_collect_memory_info: monitor.stop() mem_info = monitor.output() @@ -279,14 +304,7 @@ if __name__ == '__main__': 'memory.used'] if 'gpu' in mem_info else 0 dump_result["gpu_util"] = mem_info['gpu'][ 'utilization.gpu'] if 'gpu' in mem_info else 0 - - dump_result["runtime"] = runtime_statis["avg_time"] * 1000 - dump_result["end2end"] = np.mean(end2end_statis_repeat) * 1000 - - f.writelines("Runtime(ms): {} \n".format(str(dump_result["runtime"]))) - f.writelines("End2End(ms): {} \n".format(str(dump_result["end2end"]))) - print("Runtime(ms): {} \n".format(str(dump_result["runtime"]))) - print("End2End(ms): {} \n".format(str(dump_result["end2end"]))) + if enable_collect_memory_info: f.writelines("cpu_rss_mb: {} \n".format( str(dump_result["cpu_rss_mb"]))) @@ -297,7 +315,8 @@ if __name__ == '__main__': print("cpu_rss_mb: {} \n".format(str(dump_result["cpu_rss_mb"]))) print("gpu_rss_mb: {} \n".format(str(dump_result["gpu_rss_mb"]))) print("gpu_util: {} \n".format(str(dump_result["gpu_util"]))) - except: + except Exception as e: f.writelines("!!!!!Infer Failed\n") + raise e f.close() diff --git a/benchmark/benchmark_ppdet.py b/benchmark/benchmark_ppdet.py index 9133122b1..ab25a1da4 100755 --- a/benchmark/benchmark_ppdet.py +++ b/benchmark/benchmark_ppdet.py @@ -35,20 +35,31 @@ def parse_arguments(): parser.add_argument( "--device_id", type=int, default=0, help="device(gpu) id") parser.add_argument( - "--iter_num", + "--profile_mode", + type=str, + default="runtime", + help="runtime or end2end.") + parser.add_argument( + "--repeat", required=True, type=int, - default=300, - help="number of iterations for computing performace.") + default=1000, + help="number of repeats for profiling.") + parser.add_argument( + "--warmup", + required=True, + type=int, + default=50, + help="number of warmup for profiling.") parser.add_argument( "--device", default="cpu", - help="Type of inference device, support 'cpu', 'gpu', 'kunlunxin', 'ascend' etc.") + help="Type of inference device, support 'cpu' or 'gpu'.") parser.add_argument( "--backend", type=str, default="default", - help="inference backend, default, ort, ov, trt, paddle, paddle_trt, lite.") + help="inference backend, default, ort, ov, trt, paddle, paddle_trt.") parser.add_argument( "--enable_trt_fp16", type=ast.literal_eval, @@ -58,12 +70,17 @@ def parse_arguments(): "--enable_lite_fp16", type=ast.literal_eval, default=False, - help="whether enable fp16 in lite backend") + help="whether enable fp16 in Paddle Lite backend") parser.add_argument( "--enable_collect_memory_info", type=ast.literal_eval, default=False, help="whether enable collect memory info") + parser.add_argument(
"--include_h2d_d2h", + type=ast.literal_eval, + default=False, + help="whether run profiling with h2d and d2h") args = parser.parse_args() return args @@ -74,6 +91,8 @@ def build_option(args): backend = args.backend enable_trt_fp16 = args.enable_trt_fp16 enable_lite_fp16 = args.enable_lite_fp16 + if args.profile_mode == "runtime": + option.enable_profiling(args.include_h2d_d2h, args.repeat, args.warmup) option.set_cpu_thread_num(args.cpu_num_thread) if device == "gpu": option.use_gpu() @@ -266,8 +285,12 @@ if __name__ == '__main__': gpu_id = args.device_id enable_collect_memory_info = args.enable_collect_memory_info + enable_record_time_of_backend = args.enable_record_time_of_backend + backend_repeat = args.backend_repeat dump_result = dict() end2end_statis = list() + prepost_statis = list() + h2d_d2h_statis = list() cpu_mem = list() gpu_mem = list() gpu_util = list() @@ -317,18 +340,26 @@ if __name__ == '__main__': monitor = Monitor(enable_gpu, gpu_id) monitor.start() - model.enable_record_time_of_runtime() im_ori = cv2.imread(args.image) - for i in tqdm(range(args.iter_num)): - im = im_ori + if args.profile_mode == "runtime": + result = model.predict(im_ori) + profile_time = model.get_profile_time() + dump_result["runtime"] = profile_time * 1000 + f.writelines("Runtime(ms): {} \n".format(str(dump_result["runtime"]))) + print("Runtime(ms): {} \n".format(str(dump_result["runtime"]))) + else: + # end2end + for i in range(args.warmup): + result = model.predict(im_ori) + start = time.time() - result = model.predict(im) - end2end_statis.append(time.time() - start) + for i in tqdm(range(args.repeat)): + result = model.predict(im_ori) + end = time.time() + dump_result["end2end"] = ((end - start) / args.repeat) * 1000.0 + f.writelines("End2End(ms): {} \n".format(str(dump_result["end2end"]))) + print("End2End(ms): {} \n".format(str(dump_result["end2end"]))) - runtime_statis = model.print_statis_info_of_runtime() - - warmup_iter = args.iter_num // 5 - end2end_statis_repeat = end2end_statis[warmup_iter:] if enable_collect_memory_info: monitor.stop() mem_info = monitor.output() @@ -338,14 +369,7 @@ if __name__ == '__main__': 'memory.used'] if 'gpu' in mem_info else 0 dump_result["gpu_util"] = mem_info['gpu'][ 'utilization.gpu'] if 'gpu' in mem_info else 0 - - dump_result["runtime"] = runtime_statis["avg_time"] * 1000 - dump_result["end2end"] = np.mean(end2end_statis_repeat) * 1000 - - f.writelines("Runtime(ms): {} \n".format(str(dump_result["runtime"]))) - f.writelines("End2End(ms): {} \n".format(str(dump_result["end2end"]))) - print("Runtime(ms): {} \n".format(str(dump_result["runtime"]))) - print("End2End(ms): {} \n".format(str(dump_result["end2end"]))) + if enable_collect_memory_info: f.writelines("cpu_rss_mb: {} \n".format( str(dump_result["cpu_rss_mb"]))) @@ -356,7 +380,8 @@ if __name__ == '__main__': print("cpu_rss_mb: {} \n".format(str(dump_result["cpu_rss_mb"]))) print("gpu_rss_mb: {} \n".format(str(dump_result["gpu_rss_mb"]))) print("gpu_util: {} \n".format(str(dump_result["gpu_util"]))) - except: + except Exception as e: f.writelines("!!!!!Infer Failed\n") + raise e f.close() diff --git a/cmake/summary.cmake b/cmake/summary.cmake index faaacb417..ee2efccb7 100755 --- a/cmake/summary.cmake +++ b/cmake/summary.cmake @@ -39,6 +39,7 @@ function(fastdeploy_summary) message(STATUS " ENABLE_POROS_BACKEND : ${ENABLE_POROS_BACKEND}") message(STATUS " ENABLE_TRT_BACKEND : ${ENABLE_TRT_BACKEND}") message(STATUS " ENABLE_OPENVINO_BACKEND : ${ENABLE_OPENVINO_BACKEND}") + message(STATUS " 
ENABLE_BENCHMARK : ${ENABLE_BENCHMARK}") message(STATUS " WITH_GPU : ${WITH_GPU}") message(STATUS " WITH_ASCEND : ${WITH_ASCEND}") message(STATUS " WITH_TIMVX : ${WITH_TIMVX}") diff --git a/fastdeploy/benchmark/benchmark.h b/fastdeploy/benchmark/benchmark.h new file mode 100644 index 000000000..825fc4f54 --- /dev/null +++ b/fastdeploy/benchmark/benchmark.h @@ -0,0 +1,86 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include "fastdeploy/core/config.h" +#include "fastdeploy/utils/utils.h" +#include "fastdeploy/utils/perf.h" +#include "fastdeploy/benchmark/option.h" +#include "fastdeploy/benchmark/results.h" + +#ifdef ENABLE_BENCHMARK + #define __RUNTIME_PROFILE_LOOP_BEGIN(option, base_loop) \ + int __p_loop = (base_loop); \ + const bool __p_enable_profile = option.enable_profile; \ + const bool __p_include_h2d_d2h = option.include_h2d_d2h; \ + const int __p_repeats = option.repeats; \ + const int __p_warmup = option.warmup; \ + if (__p_enable_profile && (!__p_include_h2d_d2h)) { \ + __p_loop = (__p_repeats) + (__p_warmup); \ + FDINFO << option << std::endl; \ + } \ + TimeCounter __p_tc; \ + bool __p_tc_start = false; \ + for (int __p_i = 0; __p_i < __p_loop; ++__p_i) { \ + if (__p_i >= (__p_warmup) && (!__p_tc_start)) { \ + __p_tc.Start(); \ + __p_tc_start = true; \ + } \ + + #define __RUNTIME_PROFILE_LOOP_END(result) \ + } \ + if ((__p_enable_profile && (!__p_include_h2d_d2h))) { \ + if (__p_tc_start) { \ + __p_tc.End(); \ + double __p_tc_duration = __p_tc.Duration(); \ + result.time_of_runtime = \ + __p_tc_duration / static_cast<double>(__p_repeats); \ + } \ + } + + #define __RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN(option, base_loop) \ + int __p_loop_h = (base_loop); \ + const bool __p_enable_profile_h = option.enable_profile; \ + const bool __p_include_h2d_d2h_h = option.include_h2d_d2h; \ + const int __p_repeats_h = option.repeats; \ + const int __p_warmup_h = option.warmup; \ + if (__p_enable_profile_h && __p_include_h2d_d2h_h) { \ + __p_loop_h = (__p_repeats_h) + (__p_warmup_h); \ + FDINFO << option << std::endl; \ + } \ + TimeCounter __p_tc_h; \ + bool __p_tc_start_h = false; \ + for (int __p_i_h = 0; __p_i_h < __p_loop_h; ++__p_i_h) { \ + if (__p_i_h >= (__p_warmup_h) && (!__p_tc_start_h)) { \ + __p_tc_h.Start(); \ + __p_tc_start_h = true; \ + } \ + + #define __RUNTIME_PROFILE_LOOP_H2D_D2H_END(result) \ + } \ + if ((__p_enable_profile_h && __p_include_h2d_d2h_h)) { \ + if (__p_tc_start_h) { \ + __p_tc_h.End(); \ + double __p_tc_duration_h = __p_tc_h.Duration(); \ + result.time_of_runtime = \ + __p_tc_duration_h / static_cast<double>(__p_repeats_h); \ + } \ + } +#else + #define __RUNTIME_PROFILE_LOOP_BEGIN(option, base_loop) \ + for (int __p_i = 0; __p_i < (base_loop); ++ __p_i) { + #define __RUNTIME_PROFILE_LOOP_END(result) } + #define __RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN(option, base_loop) \ + for (int __p_i_h = 0; __p_i_h < (base_loop); ++ __p_i_h) { + #define __RUNTIME_PROFILE_LOOP_H2D_D2H_END(result) }
+#endif \ No newline at end of file diff --git a/fastdeploy/benchmark/option.h b/fastdeploy/benchmark/option.h new file mode 100644 index 000000000..6df9b473c --- /dev/null +++ b/fastdeploy/benchmark/option.h @@ -0,0 +1,47 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +namespace fastdeploy { + +/** \brief All C++ FastDeploy benchmark profile APIs are defined inside this namespace +* +*/ +namespace benchmark { + +/*! @brief Option object used to control the behavior of the benchmark profiling. + */ +struct BenchmarkOption { + int warmup = 50; ///< Warmup for backend inference. + int repeats = 100; ///< Repeats for backend inference. + bool enable_profile = false; ///< Whether to use profile or not. + bool include_h2d_d2h = false; ///< Whether to include time of H2D_D2H for time of runtime. + + friend std::ostream& operator<<( + std::ostream& output, const BenchmarkOption &option) { + if (!option.include_h2d_d2h) { + output << "Running profiling for Runtime " + << "without H2D and D2H, "; + } else { + output << "Running profiling for Runtime " + << "with H2D and D2H, "; + } + output << "Repeats: " << option.repeats << ", " + << "Warmup: " << option.warmup; + return output; + } +}; + +} // namespace benchmark +} // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/benchmark/results.h b/fastdeploy/benchmark/results.h new file mode 100644 index 000000000..ed5d003e3 --- /dev/null +++ b/fastdeploy/benchmark/results.h @@ -0,0 +1,27 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +namespace fastdeploy { +namespace benchmark { + +/*! @brief Result object used to record the time of runtime after benchmark profiling is done. + */ +struct BenchmarkResult { + /// Pure backend inference time; also includes H2D/D2H time when include_h2d_d2h is true. + double time_of_runtime = 0.0f; +}; + +} // namespace benchmark +} // namespace fastdeploy \ No newline at end of file
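Taken together, these two structs form the whole profiling contract: the option is pushed down into a backend, and the loop macros fill in the result. A self-contained sketch of that arithmetic (the structs are re-declared locally for illustration, and `total_seconds` stands in for a TimeCounter measurement):

#include <iostream>

// Local stand-ins mirroring fastdeploy::benchmark::{BenchmarkOption, BenchmarkResult}.
struct BenchmarkOption {
  int warmup = 50;
  int repeats = 100;
  bool enable_profile = false;
  bool include_h2d_d2h = false;
};
struct BenchmarkResult {
  double time_of_runtime = 0.0;
};

int main() {
  BenchmarkOption option;
  option.enable_profile = true;  // what RuntimeOption::EnableProfiling() toggles

  // The RUNTIME_PROFILE_LOOP macros time `repeats` iterations (after
  // `warmup` untimed ones) and store the per-iteration average:
  double total_seconds = 0.42;  // pretend TimeCounter::Duration() result
  BenchmarkResult result;
  result.time_of_runtime = total_seconds / static_cast<double>(option.repeats);

  std::cout << "Runtime(ms): " << result.time_of_runtime * 1000.0 << std::endl;
  return 0;
}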
diff --git a/fastdeploy/core/config.h.in b/fastdeploy/core/config.h.in index e6f202961..5593f9fd8 100755 --- a/fastdeploy/core/config.h.in +++ b/fastdeploy/core/config.h.in @@ -56,3 +56,7 @@ #ifndef ENABLE_TEXT #cmakedefine ENABLE_TEXT #endif + +#ifndef ENABLE_BENCHMARK +#cmakedefine ENABLE_BENCHMARK +#endif \ No newline at end of file diff --git a/fastdeploy/fastdeploy_model.cc b/fastdeploy/fastdeploy_model.cc index 9eff985fb..d909a6138 100644 --- a/fastdeploy/fastdeploy_model.cc +++ b/fastdeploy/fastdeploy_model.cc @@ -31,7 +31,8 @@ std::string Str(const std::vector<Backend>& backends) { return oss.str(); } -bool IsSupported(const std::vector<Backend>& backends, Backend backend) { +bool CheckBackendSupported(const std::vector<Backend>& backends, + Backend backend) { for (size_t i = 0; i < backends.size(); ++i) { if (backends[i] == backend) { return true; @@ -40,6 +41,22 @@ bool IsSupported(const std::vector<Backend>& backends, Backend backend) { return false; } +bool FastDeployModel::IsSupported(const std::vector<Backend>& backends, + Backend backend) { +#ifdef ENABLE_BENCHMARK + if (runtime_option.benchmark_option.enable_profile) { + FDWARNING << "In benchmark mode, we don't check to see if " + << "the backend [" << backend + << "] is supported for current model!" + << std::endl; + return true; + } + return CheckBackendSupported(backends, backend); +#else + return CheckBackendSupported(backends, backend); +#endif +} + bool FastDeployModel::InitRuntimeWithSpecifiedBackend() { if (!IsBackendAvailable(runtime_option.backend)) { FDERROR << runtime_option.backend @@ -367,12 +384,13 @@ bool FastDeployModel::Infer(std::vector<FDTensor>& input_tensors, tc.End(); if (time_of_runtime_.size() > 50000) { FDWARNING << "There are already 50000 records of runtime, will force to " - "disable record time of runtime now." + "disable record time of runtime now."
<< std::endl; enable_record_time_of_runtime_ = false; } time_of_runtime_.push_back(tc.Duration()); } + return ret; } @@ -416,6 +434,7 @@ std::map<std::string, float> FastDeployModel::PrintStatisInfoOfRuntime() { statis_info_of_runtime_dict["warmup_iter"] = warmup_iter; statis_info_of_runtime_dict["avg_time"] = avg_time; statis_info_of_runtime_dict["iterations"] = time_of_runtime_.size(); + return statis_info_of_runtime_dict; } } // namespace fastdeploy diff --git a/fastdeploy/fastdeploy_model.h b/fastdeploy/fastdeploy_model.h index 698827cc2..037bb2192 100755 --- a/fastdeploy/fastdeploy_model.h +++ b/fastdeploy/fastdeploy_model.h @@ -75,7 +75,7 @@ class FASTDEPLOY_DECL FastDeployModel { return runtime_initialized_ && initialized; } - /** \brief This is a debug interface, used to record the time of backend runtime + /** \brief This is a debug interface, used to record the time of runtime (backend + h2d + d2h) * * example code @code * auto model = fastdeploy::vision::PPYOLOE("model.pdmodel", "model.pdiparams", "infer_cfg.yml"); @@ -98,7 +98,7 @@ class FASTDEPLOY_DECL FastDeployModel { enable_record_time_of_runtime_ = true; } - /** \brief Disable to record the time of backend runtime, see `EnableRecordTimeOfRuntime()` for more detail + /** \brief Disable to record the time of runtime, see `EnableRecordTimeOfRuntime()` for more detail */ virtual void DisableRecordTimeOfRuntime() { enable_record_time_of_runtime_ = false; @@ -113,6 +113,11 @@ class FASTDEPLOY_DECL FastDeployModel { virtual bool EnabledRecordTimeOfRuntime() { return enable_record_time_of_runtime_; } + /** \brief Get profile time of Runtime after the profile process is done. + */ + virtual double GetProfileTime() { + return runtime_->GetProfileTime(); + } /** \brief Release reused input/output buffers */ @@ -153,13 +158,13 @@ class FASTDEPLOY_DECL FastDeployModel { bool CreateTimVXBackend(); bool CreateKunlunXinBackend(); bool CreateASCENDBackend(); + bool IsSupported(const std::vector<Backend>& backends, + Backend backend); std::shared_ptr<Runtime> runtime_; bool runtime_initialized_ = false; // whether to record inference time bool enable_record_time_of_runtime_ = false; - - // record inference time for backend std::vector<double> time_of_runtime_; }; diff --git a/fastdeploy/pybind/fastdeploy_model.cc b/fastdeploy/pybind/fastdeploy_model.cc index 0b138fa60..e90619e37 100644 --- a/fastdeploy/pybind/fastdeploy_model.cc +++ b/fastdeploy/pybind/fastdeploy_model.cc @@ -30,6 +30,8 @@ void BindFDModel(pybind11::module& m) { &FastDeployModel::DisableRecordTimeOfRuntime) .def("print_statis_info_of_runtime", &FastDeployModel::PrintStatisInfoOfRuntime) + .def("get_profile_time", + &FastDeployModel::GetProfileTime) .def("initialized", &FastDeployModel::Initialized) .def_readwrite("runtime_option", &FastDeployModel::runtime_option) .def_readwrite("valid_cpu_backends", &FastDeployModel::valid_cpu_backends) diff --git a/fastdeploy/pybind/runtime.cc b/fastdeploy/pybind/runtime.cc index 6c5c65bc2..7eeb0fdc2 100644 --- a/fastdeploy/pybind/runtime.cc +++ b/fastdeploy/pybind/runtime.cc @@ -77,6 +77,8 @@ void BindRuntime(pybind11::module& m) { .def("set_ipu_config", &RuntimeOption::SetIpuConfig) .def("delete_paddle_backend_pass", &RuntimeOption::DeletePaddleBackendPass) + .def("enable_profiling", &RuntimeOption::EnableProfiling) + .def("disable_profiling", &RuntimeOption::DisableProfiling) .def("disable_paddle_trt_ops", &RuntimeOption::DisablePaddleTrtOPs) .def_readwrite("model_file", &RuntimeOption::model_file) .def_readwrite("params_file", &RuntimeOption::params_file) @@ -217,6 +219,7 @@ void
BindRuntime(pybind11::module& m) { .def("num_outputs", &Runtime::NumOutputs) .def("get_input_info", &Runtime::GetInputInfo) .def("get_output_info", &Runtime::GetOutputInfo) + .def("get_profile_time", &Runtime::GetProfileTime) .def_readonly("option", &Runtime::option); pybind11::enum_<Backend>(m, "Backend", pybind11::arithmetic(), diff --git a/fastdeploy/runtime/backends/backend.h b/fastdeploy/runtime/backends/backend.h index 88a8e78a0..802db6fa1 100644 --- a/fastdeploy/runtime/backends/backend.h +++ b/fastdeploy/runtime/backends/backend.h @@ -22,6 +22,7 @@ #include "fastdeploy/core/fd_tensor.h" #include "fastdeploy/core/fd_type.h" #include "fastdeploy/runtime/runtime_option.h" +#include "fastdeploy/benchmark/benchmark.h" namespace fastdeploy { @@ -79,7 +80,6 @@ class BaseBackend { virtual bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs, bool copy_to_fd = true) = 0; - // Optional: For those backends which can share memory // while creating multiple inference engines with same model file virtual std::unique_ptr<BaseBackend> Clone(RuntimeOption &runtime_option, @@ -88,6 +88,70 @@ class BaseBackend { FDERROR << "Clone no support" << std::endl; return nullptr; } + + benchmark::BenchmarkOption benchmark_option_; + benchmark::BenchmarkResult benchmark_result_; }; +/** \brief Macros for Runtime benchmark profiling. + * The param 'base_loop' for 'RUNTIME_PROFILE_LOOP_BEGIN' + * indicates the minimum number of times the loop + * will repeat when profiling mode is not enabled. + * In most cases, the value should be 1, i.e., results are + * obtained by running the inference process once, when + * the profile mode is turned off, such as ONNX Runtime, + * OpenVINO, TensorRT, Paddle Inference, Paddle Lite, + * RKNPU2, SOPHGO etc. + * + * example code @code + * // OpenVINOBackend::Infer + * RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN + * // do something .... + * RUNTIME_PROFILE_LOOP_BEGIN(1) + * // The code wrapped by the 'BEGIN(1) ~ END' scope + * // will only run once when profiling mode is not enabled. + * request_.infer(); + * RUNTIME_PROFILE_LOOP_END + * // do something .... + * RUNTIME_PROFILE_LOOP_H2D_D2H_END + * + * @endcode In this case, no variables needed by subsequent code + * are created inside the BEGIN ~ END scope. But sometimes we + * need to set 'base_loop' to 0, as for POROS. + * + * example code @code + * // PorosBackend::Infer + * RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN + * // do something .... + * RUNTIME_PROFILE_LOOP_BEGIN(0) // set 'base_loop' as 0 + * // The code wrapped by the 'BEGIN(0) ~ END' scope + * // will not run when profiling mode is not enabled. + * auto poros_outputs = _poros_module->forward(poros_inputs); + * RUNTIME_PROFILE_LOOP_END + * // Run another inference beyond the scope of 'BEGIN ~ END' + * // to get valid outputs for subsequent tasks. + * auto poros_outputs = _poros_module->forward(poros_inputs); + * // do something .... will use 'poros_outputs' ... + * if (poros_outputs.isTensor()) { + * // ... + * } + * RUNTIME_PROFILE_LOOP_H2D_D2H_END + * + * @endcode In this case, 'poros_outputs' is created inside the + * BEGIN ~ END scope but is required by subsequent code. So, we + * set 'base_loop' to 0 and launch another inference beyond the + * scope of 'BEGIN ~ END' to get valid outputs for subsequent tasks.
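+ *
+ * When ENABLE_BENCHMARK is OFF, both macro pairs simply expand to a
+ * plain loop over 'base_loop' (see fastdeploy/benchmark/benchmark.h),
+ * so wrapped code behaves as before. A sketch of what 'BEGIN(1) ~ END'
+ * expands to in that configuration:
+ *
+ * example code @code
+ * for (int __p_i = 0; __p_i < 1; ++__p_i) {
+ *   request_.infer();  // runs exactly once
+ * }
+ * @endcode Hence 'BEGIN(0) ~ END' runs the wrapped code zero times when
+ * profiling is disabled, which is why the POROS example above launches
+ * a separate inference afterwards.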
+ */ + +#define RUNTIME_PROFILE_LOOP_BEGIN(base_loop) \ + __RUNTIME_PROFILE_LOOP_BEGIN(benchmark_option_, (base_loop)) +#define RUNTIME_PROFILE_LOOP_END \ + __RUNTIME_PROFILE_LOOP_END(benchmark_result_) +#define RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN \ + __RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN(benchmark_option_, 1) +#define RUNTIME_PROFILE_LOOP_H2D_D2H_END \ + __RUNTIME_PROFILE_LOOP_H2D_D2H_END(benchmark_result_) + } // namespace fastdeploy diff --git a/fastdeploy/runtime/backends/lite/configure_hardware.cc b/fastdeploy/runtime/backends/lite/configure_hardware.cc index 7c7a9993c..7ac60383f 100644 --- a/fastdeploy/runtime/backends/lite/configure_hardware.cc +++ b/fastdeploy/runtime/backends/lite/configure_hardware.cc @@ -13,23 +13,6 @@ // limitations under the License. #include "fastdeploy/runtime/backends/lite/lite_backend.h" -// https://github.com/PaddlePaddle/Paddle-Lite/issues/8290 -// When compiling the FastDeploy dynamic library, namely, -// WITH_STATIC_LIB=OFF, and depending on the Paddle Lite -// static library, you need to include the fake registration -// codes of Paddle Lite. When you compile the FastDeploy static -// library and depends on the Paddle Lite static library, -// WITH_STATIC_LIB=ON, you do not need to include the fake -// registration codes for Paddle Lite, but wait until you -// use the FastDeploy static library. -#if (defined(WITH_LITE_STATIC) && (!defined(WITH_STATIC_LIB))) -#warning You are compiling the FastDeploy dynamic library with \ -Paddle Lite static lib We will automatically add some registration \ -codes for ops, kernels and passes for Paddle Lite. -#include "paddle_use_kernels.h" // NOLINT -#include "paddle_use_ops.h" // NOLINT -#include "paddle_use_passes.h" // NOLINT -#endif #include @@ -156,4 +139,5 @@ void LiteBackend::ConfigureNNAdapter(const LiteBackendOption& option) { config_.set_nnadapter_dynamic_shape_info(option.nnadapter_dynamic_shape_info); } + } // namespace fastdeploy diff --git a/fastdeploy/runtime/backends/lite/lite_backend.cc b/fastdeploy/runtime/backends/lite/lite_backend.cc index f9d47a7a5..39cf2ebdd 100644 --- a/fastdeploy/runtime/backends/lite/lite_backend.cc +++ b/fastdeploy/runtime/backends/lite/lite_backend.cc @@ -100,7 +100,7 @@ bool LiteBackend::InitFromPaddle(const std::string& model_file, auto shape = tensor->shape(); info.shape.assign(shape.begin(), shape.end()); info.name = output_names[i]; - if (!option_.device == Device::KUNLUNXIN) { + if (option_.device != Device::KUNLUNXIN) { info.dtype = LiteDataTypeToFD(tensor->precision()); } outputs_desc_.emplace_back(info); @@ -136,6 +136,8 @@ bool LiteBackend::Infer(std::vector<FDTensor>& inputs, << inputs_desc_.size() << ")." << std::endl; return false; } + + RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN for (size_t i = 0; i < inputs.size(); ++i) { auto iter = inputs_order_.find(inputs[i].name); if (iter == inputs_order_.end()) { FDERROR << "Cannot find input with name:" << inputs[i].name << " in loaded model." << std::endl; return false; } + auto tensor = predictor_->GetInput(iter->second); // Adjust dims only, allocate lazy.
tensor->Resize(inputs[i].shape); @@ -174,8 +177,10 @@ bool LiteBackend::Infer(std::vector<FDTensor>& inputs, FDASSERT(false, "Unexpected data type of %d.", inputs[i].dtype); } } - + + RUNTIME_PROFILE_LOOP_BEGIN(1) predictor_->Run(); + RUNTIME_PROFILE_LOOP_END outputs->resize(outputs_desc_.size()); for (size_t i = 0; i < outputs_desc_.size(); ++i) { @@ -188,6 +193,7 @@ bool LiteBackend::Infer(std::vector<FDTensor>& inputs, memcpy((*outputs)[i].MutableData(), tensor->data(), (*outputs)[i].Nbytes()); } + RUNTIME_PROFILE_LOOP_H2D_D2H_END return true; } diff --git a/fastdeploy/runtime/backends/lite/lite_backend.h b/fastdeploy/runtime/backends/lite/lite_backend.h old mode 100755 new mode 100644 diff --git a/fastdeploy/runtime/backends/lite/option.h b/fastdeploy/runtime/backends/lite/option.h index 879cb3472..d94b32251 100755 --- a/fastdeploy/runtime/backends/lite/option.h +++ b/fastdeploy/runtime/backends/lite/option.h @@ -81,4 +81,4 @@ struct LiteBackendOption { bool kunlunxin_adaptive_seqlen = false; bool kunlunxin_enable_multi_stream = false; }; -} // namespace fastdeploy +} // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/runtime/backends/openvino/ov_backend.cc b/fastdeploy/runtime/backends/openvino/ov_backend.cc index 1d3134506..7f569f92c 100644 --- a/fastdeploy/runtime/backends/openvino/ov_backend.cc +++ b/fastdeploy/runtime/backends/openvino/ov_backend.cc @@ -375,6 +375,7 @@ bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs, return false; } + RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN for (size_t i = 0; i < inputs.size(); ++i) { ov::Shape shape(inputs[i].shape.begin(), inputs[i].shape.end()); ov::Tensor ov_tensor(FDDataTypeToOV(inputs[i].dtype), shape, request_.set_tensor(inputs[i].name, ov_tensor); } + RUNTIME_PROFILE_LOOP_BEGIN(1) request_.infer(); + RUNTIME_PROFILE_LOOP_END outputs->resize(output_infos_.size()); for (size_t i = 0; i < output_infos_.size(); ++i) { @@ -403,6 +406,7 @@ bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs, out_tensor.data(), Device::CPU); } } + RUNTIME_PROFILE_LOOP_H2D_D2H_END return true; } @@ -419,4 +423,4 @@ std::unique_ptr<BaseBackend> OpenVINOBackend::Clone( return new_backend; } -} // namespace fastdeploy +} // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/runtime/backends/openvino/ov_backend.h b/fastdeploy/runtime/backends/openvino/ov_backend.h index 12b282c8e..a27f17480 100644 --- a/fastdeploy/runtime/backends/openvino/ov_backend.h +++ b/fastdeploy/runtime/backends/openvino/ov_backend.h @@ -49,7 +49,7 @@ class OpenVINOBackend : public BaseBackend { std::unique_ptr<BaseBackend> Clone(RuntimeOption &runtime_option, void* stream = nullptr, int device_id = -1) override; - + private: bool InitFromPaddle(const std::string& model_file, const std::string& params_file, @@ -70,4 +70,4 @@ class OpenVINOBackend : public BaseBackend { std::vector<TensorInfo> output_infos_; }; -} // namespace fastdeploy +} // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/runtime/backends/ort/ort_backend.cc b/fastdeploy/runtime/backends/ort/ort_backend.cc index 70cb18121..58c449cc6 100644 --- a/fastdeploy/runtime/backends/ort/ort_backend.cc +++ b/fastdeploy/runtime/backends/ort/ort_backend.cc @@ -13,9 +13,6 @@ // limitations under the License.
#include "fastdeploy/runtime/backends/ort/ort_backend.h" - -#include - #include "fastdeploy/core/float16.h" #include "fastdeploy/runtime/backends/ort/ops/adaptive_pool2d.h" #include "fastdeploy/runtime/backends/ort/ops/multiclass_nms.h" @@ -25,6 +22,9 @@ #include "paddle2onnx/converter.h" #endif +#include + + namespace fastdeploy { std::vector OrtBackend::custom_operators_ = @@ -258,6 +258,7 @@ bool OrtBackend::Infer(std::vector& inputs, } // from FDTensor to Ort Inputs + RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN for (size_t i = 0; i < inputs.size(); ++i) { auto ort_value = CreateOrtValue(inputs[i], option_.device == Device::GPU); binding_->BindInput(inputs[i].name.c_str(), ort_value); @@ -270,12 +271,14 @@ bool OrtBackend::Infer(std::vector& inputs, } // Inference with inputs + RUNTIME_PROFILE_LOOP_BEGIN(1) try { session_.Run({}, *(binding_.get())); } catch (const std::exception& e) { FDERROR << "Failed to Infer: " << e.what() << std::endl; return false; } + RUNTIME_PROFILE_LOOP_END // Convert result after inference std::vector ort_outputs = binding_->GetOutputValues(); @@ -284,7 +287,7 @@ bool OrtBackend::Infer(std::vector& inputs, OrtValueToFDTensor(ort_outputs[i], &((*outputs)[i]), outputs_desc_[i].name, copy_to_fd); } - + RUNTIME_PROFILE_LOOP_H2D_D2H_END return true; } diff --git a/fastdeploy/runtime/backends/ort/ort_backend.h b/fastdeploy/runtime/backends/ort/ort_backend.h index 61308b9da..e0caf48a3 100644 --- a/fastdeploy/runtime/backends/ort/ort_backend.h +++ b/fastdeploy/runtime/backends/ort/ort_backend.h @@ -54,7 +54,7 @@ class OrtBackend : public BaseBackend { std::vector GetOutputInfos() override; static std::vector custom_operators_; void InitCustomOperators(); - + private: bool InitFromPaddle(const std::string& model_buffer, const std::string& params_buffer, diff --git a/fastdeploy/runtime/backends/paddle/paddle_backend.cc b/fastdeploy/runtime/backends/paddle/paddle_backend.cc index 90bd27682..e210293b0 100644 --- a/fastdeploy/runtime/backends/paddle/paddle_backend.cc +++ b/fastdeploy/runtime/backends/paddle/paddle_backend.cc @@ -221,14 +221,17 @@ bool PaddleBackend::Infer(std::vector& inputs, << inputs_desc_.size() << ")." 
<< std::endl; return false; } - + + RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN for (size_t i = 0; i < inputs.size(); ++i) { auto handle = predictor_->GetInputHandle(inputs[i].name); ShareTensorFromFDTensor(handle.get(), inputs[i]); } - + + RUNTIME_PROFILE_LOOP_BEGIN(1) predictor_->Run(); - + RUNTIME_PROFILE_LOOP_END + // output share backend memory only support CPU or GPU if (option_.use_ipu) { copy_to_fd = true; } @@ -241,6 +244,7 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs, } PaddleTensorToFDTensor(handle, &((*outputs)[i]), copy_to_fd); } + RUNTIME_PROFILE_LOOP_H2D_D2H_END return true; } @@ -381,4 +385,4 @@ void PaddleBackend::CollectShapeRun( predictor->Run(); } -} // namespace fastdeploy +} // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/runtime/backends/paddle/paddle_backend.h b/fastdeploy/runtime/backends/paddle/paddle_backend.h index 8cde22cfd..02c430ade 100755 --- a/fastdeploy/runtime/backends/paddle/paddle_backend.h +++ b/fastdeploy/runtime/backends/paddle/paddle_backend.h @@ -89,4 +89,4 @@ class PaddleBackend : public BaseBackend { std::vector<TensorInfo> inputs_desc_; std::vector<TensorInfo> outputs_desc_; }; -} // namespace fastdeploy +} // namespace fastdeploy \ No newline at end of file diff --git a/fastdeploy/runtime/backends/tensorrt/trt_backend.cc b/fastdeploy/runtime/backends/tensorrt/trt_backend.cc index 6972cf8ed..d64a946f7 100644 --- a/fastdeploy/runtime/backends/tensorrt/trt_backend.cc +++ b/fastdeploy/runtime/backends/tensorrt/trt_backend.cc @@ -287,14 +287,18 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs, BuildTrtEngine(); } + RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN cudaSetDevice(option_.gpu_id); SetInputs(inputs); AllocateOutputsBuffer(outputs, copy_to_fd); - + + RUNTIME_PROFILE_LOOP_BEGIN(1) if (!context_->enqueueV2(bindings_.data(), stream_, nullptr)) { FDERROR << "Failed to Infer with TensorRT."
<< std::endl; return false; } + RUNTIME_PROFILE_LOOP_END + for (size_t i = 0; i < outputs->size(); ++i) { // if the final output tensor's dtype is different from the model output // tensor's dtype, then we need cast the data to the final output's dtype @@ -335,7 +339,7 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs, FDASSERT(cudaStreamSynchronize(stream_) == cudaSuccess, "[ERROR] Error occurs while sync cuda stream."); } - + RUNTIME_PROFILE_LOOP_H2D_D2H_END return true; } diff --git a/fastdeploy/runtime/runtime.cc b/fastdeploy/runtime/runtime.cc index 0669f52dc..1ed82891a 100644 --- a/fastdeploy/runtime/runtime.cc +++ b/fastdeploy/runtime/runtime.cc @@ -275,6 +275,8 @@ void Runtime::CreatePaddleBackend() { #endif backend_ = utils::make_unique<PaddleBackend>(); auto casted_backend = dynamic_cast<PaddleBackend*>(backend_.get()); + casted_backend->benchmark_option_ = option.benchmark_option; + if (pd_option.model_from_memory_) { FDASSERT(casted_backend->InitFromPaddle(option.model_file, option.params_file, pd_option), @@ -303,6 +305,7 @@ void Runtime::CreatePaddleBackend() { void Runtime::CreateOpenVINOBackend() { #ifdef ENABLE_OPENVINO_BACKEND backend_ = utils::make_unique<OpenVINOBackend>(); + backend_->benchmark_option_ = option.benchmark_option; FDASSERT(backend_->Init(option), "Failed to initialize OpenVINOBackend."); #else FDASSERT(false, @@ -316,6 +319,8 @@ void Runtime::CreateOpenVINOBackend() { void Runtime::CreateOrtBackend() { #ifdef ENABLE_ORT_BACKEND backend_ = utils::make_unique<OrtBackend>(); + backend_->benchmark_option_ = option.benchmark_option; + FDASSERT(backend_->Init(option), "Failed to initialize Backend::ORT."); #else FDASSERT(false, @@ -351,6 +356,8 @@ void Runtime::CreateTrtBackend() { trt_option.external_stream_ = option.external_stream_; backend_ = utils::make_unique<TrtBackend>(); auto casted_backend = dynamic_cast<TrtBackend*>(backend_.get()); + casted_backend->benchmark_option_ = option.benchmark_option; + if (option.model_format == ModelFormat::ONNX) { if (option.model_from_memory_) { FDASSERT(casted_backend->InitFromOnnx(option.model_file, trt_option), @@ -403,6 +410,8 @@ void Runtime::CreateLiteBackend() { "LiteBackend only support model format of ModelFormat::PADDLE"); backend_ = utils::make_unique<LiteBackend>(); auto casted_backend = dynamic_cast<LiteBackend*>(backend_.get()); + casted_backend->benchmark_option_ = option.benchmark_option; + FDASSERT(casted_backend->InitFromPaddle(option.model_file, option.params_file, option.paddle_lite_option), "Load model from nb file failed while initializing LiteBackend."); diff --git a/fastdeploy/runtime/runtime.h b/fastdeploy/runtime/runtime.h index 22a09c355..6e7dc9629 100755 --- a/fastdeploy/runtime/runtime.h +++ b/fastdeploy/runtime/runtime.h @@ -95,6 +95,11 @@ struct FASTDEPLOY_DECL Runtime { */ bool Compile(std::vector<std::vector<FDTensor>>& prewarm_tensors, const RuntimeOption& _option); + /** \brief Get profile time of Runtime after the profile process is done.
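+ *
+ * example code @code
+ * // A hedged sketch: assumes option.EnableProfiling(...) was called
+ * // before Init(), and Infer() has been run once since then.
+ * double avg_seconds = runtime.GetProfileTime();
+ * std::cout << avg_seconds * 1000.0 << " ms" << std::endl;
+ * @endcode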
+ */ + double GetProfileTime() { + return backend_->benchmark_result_.time_of_runtime; + } private: void CreateOrtBackend(); diff --git a/fastdeploy/runtime/runtime_option.h b/fastdeploy/runtime/runtime_option.h index e51fb9be2..64222f359 100644 --- a/fastdeploy/runtime/runtime_option.h +++ b/fastdeploy/runtime/runtime_option.h @@ -32,6 +32,7 @@ #include "fastdeploy/runtime/backends/rknpu2/option.h" #include "fastdeploy/runtime/backends/sophgo/option.h" #include "fastdeploy/runtime/backends/tensorrt/option.h" +#include "fastdeploy/benchmark/option.h" namespace fastdeploy { @@ -346,6 +347,26 @@ struct FASTDEPLOY_DECL RuntimeOption { void SetIpuConfig(bool enable_fp16 = false, int replica_num = 1, float available_memory_proportion = 1.0, bool enable_half_partial = false); + + /** \brief Set the profile mode as 'true'. + * + * \param[in] include_h2d_d2h Whether to include time of H2D_D2H for time of runtime. + * \param[in] repeat Repeat times for runtime inference. + * \param[in] warmup Warmup times for runtime inference. + */ + void EnableProfiling(bool include_h2d_d2h = false, + int repeat = 100, int warmup = 50) { + benchmark_option.enable_profile = true; + benchmark_option.warmup = warmup; + benchmark_option.repeats = repeat; + benchmark_option.include_h2d_d2h = include_h2d_d2h; + } + + /** \brief Set the profile mode as 'false'. + */ + void DisableProfiling() { + benchmark_option.enable_profile = false; + } Backend backend = Backend::UNKNOWN; @@ -419,6 +440,9 @@ struct FASTDEPLOY_DECL RuntimeOption { bool model_from_memory_ = false; // format of input model ModelFormat model_format = ModelFormat::PADDLE; + + // Benchmark option + benchmark::BenchmarkOption benchmark_option; }; } // namespace fastdeploy diff --git a/python/fastdeploy/model.py b/python/fastdeploy/model.py index 59833f775..224cbafdf 100644 --- a/python/fastdeploy/model.py +++ b/python/fastdeploy/model.py @@ -54,6 +54,11 @@ class FastDeployModel: def print_statis_info_of_runtime(self): return self._model.print_statis_info_of_runtime() + def get_profile_time(self): + """Get profile time of Runtime after the profile process is done. + """ + return self._model.get_profile_time() + @property def runtime_option(self): return self._model.runtime_option if self._model is not None else None diff --git a/python/fastdeploy/runtime.py b/python/fastdeploy/runtime.py index d864a8897..b3da670bb 100755 --- a/python/fastdeploy/runtime.py +++ b/python/fastdeploy/runtime.py @@ -144,6 +144,11 @@ class Runtime: index, self.num_outputs) return self._runtime.get_output_info(index) + def get_profile_time(self): + """Get profile time of Runtime after the profile process is done. + """ + return self._runtime.get_profile_time() + class RuntimeOption: """Options for FastDeploy Runtime. @@ -552,6 +557,21 @@ class RuntimeOption: available_memory_proportion, enable_half_partial) + def enable_profiling(self, + include_h2d_d2h=False, + repeat=100, warmup=50): + """Set the profile mode as 'true'. + :param include_h2d_d2h Whether to include time of H2D_D2H for time of runtime. + :param repeat Repeat times for runtime inference. + :param warmup Warmup times for runtime inference. + """ + return self._option.enable_profiling(include_h2d_d2h, repeat, warmup) + + def disable_profiling(self): + """Set the profile mode as 'false'.
+ """ + return self._option.disable_profiling() + def __repr__(self): attrs = dir(self._option) message = "RuntimeOption(\n" diff --git a/python/setup.py b/python/setup.py index d1b02254e..01246283a 100755 --- a/python/setup.py +++ b/python/setup.py @@ -73,6 +73,7 @@ setup_configs["ENABLE_VISION"] = os.getenv("ENABLE_VISION", "OFF") setup_configs["ENABLE_ENCRYPTION"] = os.getenv("ENABLE_ENCRYPTION", "OFF") setup_configs["ENABLE_FLYCV"] = os.getenv("ENABLE_FLYCV", "OFF") setup_configs["ENABLE_TEXT"] = os.getenv("ENABLE_TEXT", "OFF") +setup_configs["ENABLE_BENCHMARK"] = os.getenv("ENABLE_BENCHMARK", "OFF") setup_configs["WITH_GPU"] = os.getenv("WITH_GPU", "OFF") setup_configs["WITH_IPU"] = os.getenv("WITH_IPU", "OFF") setup_configs["WITH_KUNLUNXIN"] = os.getenv("WITH_KUNLUNXIN", "OFF") diff --git a/scripts/linux/build_linux_x86_64_cpp_cpu.sh b/scripts/linux/build_linux_x86_64_cpp_cpu.sh new file mode 100755 index 000000000..e3ff7964b --- /dev/null +++ b/scripts/linux/build_linux_x86_64_cpp_cpu.sh @@ -0,0 +1,80 @@ +#!/bin/bash +set -e +set +x + +# ------------------------------------------------------------------------------- +# readonly global variables +# ------------------------------------------------------------------------------- +readonly ROOT_PATH=$(pwd) +readonly BUILD_ROOT=build/Linux +readonly BUILD_DIR=${BUILD_ROOT}/x86_64 + +# ------------------------------------------------------------------------------- +# tasks +# ------------------------------------------------------------------------------- +__make_build_dir() { + if [ ! -d "${BUILD_DIR}" ]; then + echo "-- [INFO] BUILD_DIR: ${BUILD_DIR} not exists, setup manually ..." + if [ ! -d "${BUILD_ROOT}" ]; then + mkdir -p "${BUILD_ROOT}" && echo "-- [INFO] Created ${BUILD_ROOT} !" + fi + mkdir -p "${BUILD_DIR}" && echo "-- [INFO] Created ${BUILD_DIR} !" + else + echo "-- [INFO] Found BUILD_DIR: ${BUILD_DIR}" + fi +} + +__check_cxx_envs() { + if [ $LDFLAGS ]; then + echo "-- [INFO] Found LDFLAGS: ${LDFLAGS}, \c" + echo "unset it before crossing compiling ${BUILD_DIR}" + unset LDFLAGS + fi + if [ $CPPFLAGS ]; then + echo "-- [INFO] Found CPPFLAGS: ${CPPFLAGS}, \c" + echo "unset it before crossing compiling ${BUILD_DIR}" + unset CPPFLAGS + fi + if [ $CPLUS_INCLUDE_PATH ]; then + echo "-- [INFO] Found CPLUS_INCLUDE_PATH: ${CPLUS_INCLUDE_PATH}, \c" + echo "unset it before crossing compiling ${BUILD_DIR}" + unset CPLUS_INCLUDE_PATH + fi + if [ $C_INCLUDE_PATH ]; then + echo "-- [INFO] Found C_INCLUDE_PATH: ${C_INCLUDE_PATH}, \c" + echo "unset it before crossing compiling ${BUILD_DIR}" + unset C_INCLUDE_PATH + fi +} + +__build_fastdeploy_linux_x86_64_shared() { + + local FASDEPLOY_INSTALL_DIR="${ROOT_PATH}/${BUILD_DIR}/install" + cd "${BUILD_DIR}" && echo "-- [INFO] Working Dir: ${PWD}" + + cmake -DCMAKE_BUILD_TYPE=Release \ + -DWITH_GPU=OFF \ + -DENABLE_ORT_BACKEND=ON \ + -DENABLE_PADDLE_BACKEND=ON \ + -DENABLE_OPENVINO_BACKEND=ON \ + -DENABLE_PADDLE2ONNX=ON \ + -DENABLE_VISION=ON \ + -DENABLE_BENCHMARK=ON \ + -DBUILD_EXAMPLES=ON \ + -DCMAKE_INSTALL_PREFIX=${FASDEPLOY_INSTALL_DIR} \ + -Wno-dev ../../.. 
&& make -j8 && make install + + echo "-- [INFO][built][x86_64][${BUILD_DIR}/install]" +} + +main() { + __make_build_dir + __check_cxx_envs + __build_fastdeploy_linux_x86_64_shared + exit 0 +} + +main + +# Usage: +# ./scripts/linux/build_linux_x86_64_cpp_cpu.sh diff --git a/scripts/linux/build_linux_x86_64_cpp_gpu.sh b/scripts/linux/build_linux_x86_64_cpp_gpu.sh new file mode 100755 index 000000000..6f2b4ed7d --- /dev/null +++ b/scripts/linux/build_linux_x86_64_cpp_gpu.sh @@ -0,0 +1,83 @@ +#!/bin/bash +set -e +set +x + +# ------------------------------------------------------------------------------- +# readonly global variables +# ------------------------------------------------------------------------------- +readonly ROOT_PATH=$(pwd) +readonly BUILD_ROOT=build/Linux +readonly BUILD_DIR="${BUILD_ROOT}/x86_64_gpu" + +# ------------------------------------------------------------------------------- +# tasks +# ------------------------------------------------------------------------------- +__make_build_dir() { + if [ ! -d "${BUILD_DIR}" ]; then + echo "-- [INFO] BUILD_DIR: ${BUILD_DIR} not exists, setup manually ..." + if [ ! -d "${BUILD_ROOT}" ]; then + mkdir -p "${BUILD_ROOT}" && echo "-- [INFO] Created ${BUILD_ROOT} !" + fi + mkdir -p "${BUILD_DIR}" && echo "-- [INFO] Created ${BUILD_DIR} !" + else + echo "-- [INFO] Found BUILD_DIR: ${BUILD_DIR}" + fi +} + +__check_cxx_envs() { + if [ $LDFLAGS ]; then + echo "-- [INFO] Found LDFLAGS: ${LDFLAGS}, \c" + echo "unset it before cross-compiling ${BUILD_DIR}" + unset LDFLAGS + fi + if [ $CPPFLAGS ]; then + echo "-- [INFO] Found CPPFLAGS: ${CPPFLAGS}, \c" + echo "unset it before cross-compiling ${BUILD_DIR}" + unset CPPFLAGS + fi + if [ $CPLUS_INCLUDE_PATH ]; then + echo "-- [INFO] Found CPLUS_INCLUDE_PATH: ${CPLUS_INCLUDE_PATH}, \c" + echo "unset it before cross-compiling ${BUILD_DIR}" + unset CPLUS_INCLUDE_PATH + fi + if [ $C_INCLUDE_PATH ]; then + echo "-- [INFO] Found C_INCLUDE_PATH: ${C_INCLUDE_PATH}, \c" + echo "unset it before cross-compiling ${BUILD_DIR}" + unset C_INCLUDE_PATH + fi +} + +__build_fastdeploy_linux_x86_64_gpu_shared() { + + local FASTDEPLOY_INSTALL_DIR="${ROOT_PATH}/${BUILD_DIR}/install" + cd "${BUILD_DIR}" && echo "-- [INFO] Working Dir: ${PWD}" + + cmake -DCMAKE_BUILD_TYPE=Release \ + -DWITH_GPU=ON \ + -DTRT_DIRECTORY=${TRT_DIRECTORY} \ + -DCUDA_DIRECTORY=${CUDA_DIRECTORY} \ + -DENABLE_ORT_BACKEND=ON \ + -DENABLE_TRT_BACKEND=ON \ + -DENABLE_PADDLE_BACKEND=ON \ + -DENABLE_OPENVINO_BACKEND=ON \ + -DENABLE_PADDLE2ONNX=ON \ + -DENABLE_VISION=ON \ + -DENABLE_BENCHMARK=ON \ + -DBUILD_EXAMPLES=ON \ + -DCMAKE_INSTALL_PREFIX=${FASTDEPLOY_INSTALL_DIR} \ + -Wno-dev ../../..
&& make -j8 && make install + + echo "-- [INFO][built][x86_64_gpu][${BUILD_DIR}/install]" +} + +main() { + __make_build_dir + __check_cxx_envs + __build_fastdeploy_linux_x86_64_gpu_shared + exit 0 +} + +main + +# Usage: +# ./scripts/linux/build_linux_x86_64_cpp_gpu.sh diff --git a/scripts/macosx/build_macosx_cpp.sh b/scripts/macosx/build_macosx_cpp.sh new file mode 100755 index 000000000..4d8e08726 --- /dev/null +++ b/scripts/macosx/build_macosx_cpp.sh @@ -0,0 +1,102 @@ +#!/bin/bash +set -e +set +x + +# ------------------------------------------------------------------------------- +# readonly global variables +# ------------------------------------------------------------------------------- +readonly ROOT_PATH=$(pwd) +readonly BUILD_ROOT=build/MacOSX +readonly OSX_ARCH=$1 # arm64, x86_64 +readonly BUILD_DIR=${BUILD_ROOT}/${OSX_ARCH} + +# ------------------------------------------------------------------------------- +# tasks +# ------------------------------------------------------------------------------- +__make_build_dir() { + if [ ! -d "${BUILD_DIR}" ]; then + echo "-- [INFO] BUILD_DIR: ${BUILD_DIR} not exists, setup manually ..." + if [ ! -d "${BUILD_ROOT}" ]; then + mkdir -p "${BUILD_ROOT}" && echo "-- [INFO] Created ${BUILD_ROOT} !" + fi + mkdir -p "${BUILD_DIR}" && echo "-- [INFO] Created ${BUILD_DIR} !" + else + echo "-- [INFO] Found BUILD_DIR: ${BUILD_DIR}" + fi +} + +__check_cxx_envs() { + if [ $LDFLAGS ]; then + echo "-- [INFO] Found LDFLAGS: ${LDFLAGS}, \c" + echo "unset it before cross-compiling ${BUILD_DIR}" + unset LDFLAGS + fi + if [ $CPPFLAGS ]; then + echo "-- [INFO] Found CPPFLAGS: ${CPPFLAGS}, \c" + echo "unset it before cross-compiling ${BUILD_DIR}" + unset CPPFLAGS + fi + if [ $CPLUS_INCLUDE_PATH ]; then + echo "-- [INFO] Found CPLUS_INCLUDE_PATH: ${CPLUS_INCLUDE_PATH}, \c" + echo "unset it before cross-compiling ${BUILD_DIR}" + unset CPLUS_INCLUDE_PATH + fi + if [ $C_INCLUDE_PATH ]; then + echo "-- [INFO] Found C_INCLUDE_PATH: ${C_INCLUDE_PATH}, \c" + echo "unset it before cross-compiling ${BUILD_DIR}" + unset C_INCLUDE_PATH + fi +} + +__build_fastdeploy_osx_arm64_shared() { + + local FASTDEPLOY_INSTALL_DIR="${ROOT_PATH}/${BUILD_DIR}/install" + cd "${BUILD_DIR}" && echo "-- [INFO] Working Dir: ${PWD}" + + cmake -DCMAKE_BUILD_TYPE=MinSizeRel \ + -DENABLE_ORT_BACKEND=ON \ + -DENABLE_PADDLE2ONNX=ON \ + -DENABLE_VISION=ON \ + -DENABLE_BENCHMARK=ON \ + -DBUILD_EXAMPLES=ON \ + -DCMAKE_INSTALL_PREFIX=${FASTDEPLOY_INSTALL_DIR} \ + -Wno-dev ../../.. && make -j8 && make install + + echo "-- [INFO][built][${OSX_ARCH}][${BUILD_DIR}/install]" +} + +__build_fastdeploy_osx_x86_64_shared() { + + local FASTDEPLOY_INSTALL_DIR="${ROOT_PATH}/${BUILD_DIR}/install" + cd "${BUILD_DIR}" && echo "-- [INFO] Working Dir: ${PWD}" + + cmake -DCMAKE_BUILD_TYPE=MinSizeRel \ + -DENABLE_ORT_BACKEND=ON \ + -DENABLE_PADDLE_BACKEND=ON \ + -DENABLE_OPENVINO_BACKEND=ON \ + -DENABLE_PADDLE2ONNX=ON \ + -DENABLE_VISION=ON \ + -DENABLE_BENCHMARK=ON \ + -DBUILD_EXAMPLES=ON \ + -DCMAKE_INSTALL_PREFIX=${FASTDEPLOY_INSTALL_DIR} \ + -Wno-dev ../../.. && make -j8 && make install + + echo "-- [INFO][built][${OSX_ARCH}][${BUILD_DIR}/install]" +} + +main() { + __make_build_dir + __check_cxx_envs + if [ "$OSX_ARCH" = "arm64" ]; then + __build_fastdeploy_osx_arm64_shared + else + __build_fastdeploy_osx_x86_64_shared + fi + exit 0 +} + +main + +# Usage:
# ./scripts/macosx/build_macosx_cpp.sh arm64
# ./scripts/macosx/build_macosx_cpp.sh x86_64
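As a closing note, the timing logic the RUNTIME_PROFILE_LOOP macros implement can be distilled into a few dependency-free lines. This sketch uses std::chrono in place of fastdeploy's TimeCounter and a sleep in place of real inference; it shows exactly when the clock starts and how time_of_runtime is averaged:

#include <chrono>
#include <iostream>
#include <thread>

int main() {
  const int warmup = 50, repeats = 100;
  std::chrono::steady_clock::time_point start;
  bool started = false;
  for (int i = 0; i < warmup + repeats; ++i) {
    if (i >= warmup && !started) {  // clock starts only after the warmup runs
      start = std::chrono::steady_clock::now();
      started = true;
    }
    std::this_thread::sleep_for(std::chrono::microseconds(200));  // "inference"
  }
  double total = std::chrono::duration<double>(
      std::chrono::steady_clock::now() - start).count();
  // Same averaging as __RUNTIME_PROFILE_LOOP_END: duration / repeats.
  std::cout << "avg time_of_runtime: " << total / repeats << " s" << std::endl;
  return 0;
}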