[Backend] support benchmark mode for runtime and backend (#1201)

* [backend] support benchmark mode for runtime and backend

* [backend] support benchmark mode for runtime and backend

* [pybind11] add benchmark methods pybind

* [pybind11] add benchmark methods pybind

* [Other] Update build scripts

* [Other] Update cmake/summary.cmake

* [Other] update build scripts

* [Other] add ENABLE_BENCHMARK option -> setup.py

* optimize backend time recording

* optimize backend time recording

* optimize trt backend time record

* [backend] optimize backend_time recording for trt

* [benchmark] remove redundant logs

* fixed ov_backend conflict

* [benchmark] fixed paddle_backend conflicts

* [benchmark] fixed paddle_backend conflicts

* [benchmark] fixed paddle_backend conflicts

* [benchmark] remove use_gpu option from ort backend option

* [benchmark] update benchmark_ppdet.py

* [benchmark] update benchmark_ppcls.py

* fixed lite backend conflicts

* [Lite] fixed lite xpu

* add benchmark macro

* add RUNTIME_PROFILE_LOOP macros

* add comments for RUNTIME_PROFILE macros

* add comments for new apis

* add comments for new apis

* update benchmark_ppdet.py

* fixed bugs

* remove unused codes

* optimize RUNTIME_PROFILE_LOOP macros

* optimize RUNTIME_PROFILE_LOOP macros

* add comments for benchmark option and result

* add docs for benchmark namespace
Author: DefTruth
Date: 2023-02-06 14:29:35 +08:00
Committed by: GitHub
Parent: 42d14e7119
Commit: f73a538f61
34 changed files with 741 additions and 91 deletions


@@ -68,6 +68,7 @@ option(ENABLE_TEXT "Whether to enable text models usage." OFF)
option(ENABLE_FLYCV "Whether to enable flycv to boost image preprocess." OFF)
option(ENABLE_CVCUDA "Whether to enable NVIDIA CV-CUDA to boost image preprocess." OFF)
option(ENABLE_ENCRYPTION "Whether to enable ENCRYPTION." OFF)
option(ENABLE_BENCHMARK "Whether to enable Benchmark mode." OFF)
option(WITH_ASCEND "Whether to compile for Huawei Ascend deploy." OFF)
option(WITH_TIMVX "Whether to compile for TIMVX deploy." OFF)
option(WITH_KUNLUNXIN "Whether to compile for KunlunXin XPU deploy." OFF)

benchmark/.gitignore — vendored, new normal file, 13 additions

@@ -0,0 +1,13 @@
*.tgz
*.zip
*.tar
*.tar.gz
*.tgz
*.jpg
*.png
*.jpeg
*.txt
*.log
yolov8_s_*
._yolov8_s_*
Mobile*


@@ -17,7 +17,7 @@ import cv2
import os
import numpy as np
import time
from tqdm import tqdm
def parse_arguments():
import argparse
@@ -35,11 +35,22 @@ def parse_arguments():
parser.add_argument(
"--device_id", type=int, default=0, help="device(gpu) id")
parser.add_argument(
"--iter_num",
"--profile_mode",
type=str,
default="runtime",
help="runtime or end2end.")
parser.add_argument(
"--repeat",
required=True,
type=int,
default=300,
help="number of iterations for computing performace.")
default=1000,
help="number of repeats for profiling.")
parser.add_argument(
"--warmup",
required=True,
type=int,
default=50,
help="number of warmup for profiling.")
parser.add_argument(
"--device",
default="cpu",
@@ -59,6 +70,11 @@ def parse_arguments():
type=ast.literal_eval,
default=False,
help="whether enable collect memory info")
parser.add_argument(
"--include_h2d_d2h",
type=ast.literal_eval,
default=False,
help="whether run profiling with h2d and d2h")
args = parser.parse_args()
return args
@@ -68,6 +84,8 @@ def build_option(args):
device = args.device
backend = args.backend
enable_trt_fp16 = args.enable_trt_fp16
if args.profile_mode == "runtime":
option.enable_profiling(args.include_h2d_d2h, args.repeat, args.warmup)
option.set_cpu_thread_num(args.cpu_num_thread)
if device == "gpu":
option.use_gpu()
@@ -229,7 +247,6 @@ if __name__ == '__main__':
gpu_id = args.device_id
enable_collect_memory_info = args.enable_collect_memory_info
dump_result = dict()
end2end_statis = list()
cpu_mem = list()
gpu_mem = list()
gpu_util = list()
@@ -258,18 +275,26 @@ if __name__ == '__main__':
monitor = Monitor(enable_gpu, gpu_id)
monitor.start()
model.enable_record_time_of_runtime()
im_ori = cv2.imread(args.image)
for i in range(args.iter_num):
im = im_ori
if args.profile_mode == "runtime":
result = model.predict(im_ori)
profile_time = model.get_profile_time()
dump_result["runtime"] = profile_time * 1000
f.writelines("Runtime(ms): {} \n".format(str(dump_result["runtime"])))
print("Runtime(ms): {} \n".format(str(dump_result["runtime"])))
else:
# end2end
for i in range(args.warmup):
result = model.predict(im_ori)
start = time.time()
result = model.predict(im)
end2end_statis.append(time.time() - start)
for i in tqdm(range(args.repeat)):
result = model.predict(im_ori)
end = time.time()
dump_result["end2end"] = ((end - start) / args.repeat) * 1000.0
f.writelines("End2End(ms): {} \n".format(str(dump_result["end2end"])))
print("End2End(ms): {} \n".format(str(dump_result["end2end"])))
runtime_statis = model.print_statis_info_of_runtime()
warmup_iter = args.iter_num // 5
end2end_statis_repeat = end2end_statis[warmup_iter:]
if enable_collect_memory_info:
monitor.stop()
mem_info = monitor.output()
@@ -280,13 +305,6 @@ if __name__ == '__main__':
dump_result["gpu_util"] = mem_info['gpu'][
'utilization.gpu'] if 'gpu' in mem_info else 0
dump_result["runtime"] = runtime_statis["avg_time"] * 1000
dump_result["end2end"] = np.mean(end2end_statis_repeat) * 1000
f.writelines("Runtime(ms): {} \n".format(str(dump_result["runtime"])))
f.writelines("End2End(ms): {} \n".format(str(dump_result["end2end"])))
print("Runtime(ms): {} \n".format(str(dump_result["runtime"])))
print("End2End(ms): {} \n".format(str(dump_result["end2end"])))
if enable_collect_memory_info:
f.writelines("cpu_rss_mb: {} \n".format(
str(dump_result["cpu_rss_mb"])))
@@ -297,7 +315,8 @@ if __name__ == '__main__':
print("cpu_rss_mb: {} \n".format(str(dump_result["cpu_rss_mb"])))
print("gpu_rss_mb: {} \n".format(str(dump_result["gpu_rss_mb"])))
print("gpu_util: {} \n".format(str(dump_result["gpu_util"])))
except:
except Exception as e:
f.writelines("!!!!!Infer Failed\n")
raise e
f.close()
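
The updated script drops the manual per-iteration timing loop: in "runtime" mode the backend's built-in profiler performs the warmup/repeat runs and reports an averaged time, while "end2end" mode keeps plain wall-clock timing around predict(). A minimal sketch of both paths, assuming a PP-YOLOE detection model and placeholder file paths:

import time
import cv2
import fastdeploy as fd

profile_mode = "runtime"                 # or "end2end", mirroring --profile_mode
repeat, warmup = 1000, 50

option = fd.RuntimeOption()
if profile_mode == "runtime":
    # let the backend's profiling loop run warmup + repeat iterations internally
    option.enable_profiling(False, repeat, warmup)   # include_h2d_d2h, repeat, warmup

# hypothetical model/config/image paths, for illustration only
model = fd.vision.detection.PPYOLOE(
    "model.pdmodel", "model.pdiparams", "infer_cfg.yml", runtime_option=option)
im = cv2.imread("test.jpg")

if profile_mode == "runtime":
    model.predict(im)                                # a single call drives the whole loop
    print("Runtime(ms): {}".format(model.get_profile_time() * 1000))
else:                                                # "end2end": wall-clock around predict()
    for _ in range(warmup):
        model.predict(im)
    start = time.time()
    for _ in range(repeat):
        model.predict(im)
    print("End2End(ms): {}".format((time.time() - start) / repeat * 1000.0))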


@@ -17,6 +17,7 @@ import cv2
import os
import numpy as np
import time
from sympy import EX
from tqdm import tqdm
def parse_arguments():
@@ -24,7 +25,7 @@ def parse_arguments():
import ast
parser = argparse.ArgumentParser()
parser.add_argument(
"--model", required=True, help="Path of PaddleDetection model.")
"--model", required=True, help="Path of PaddleClas model.")
parser.add_argument(
"--image", type=str, required=False, help="Path of test image file.")
parser.add_argument(
@@ -35,20 +36,31 @@ def parse_arguments():
parser.add_argument(
"--device_id", type=int, default=0, help="device(gpu) id")
parser.add_argument(
"--iter_num",
"--profile_mode",
type=str,
default="runtime",
help="runtime or end2end.")
parser.add_argument(
"--repeat",
required=True,
type=int,
default=300,
help="number of iterations for computing performace.")
default=1000,
help="number of repeats for profiling.")
parser.add_argument(
"--warmup",
required=True,
type=int,
default=50,
help="number of warmup for profiling.")
parser.add_argument(
"--device",
default="cpu",
help="Type of inference device, support 'cpu', 'gpu', 'kunlunxin', 'ascend' etc.")
help="Type of inference device, support 'cpu' or 'gpu'.")
parser.add_argument(
"--backend",
type=str,
default="default",
help="inference backend, default, ort, ov, trt, paddle, paddle_trt, lite.")
help="inference backend, default, ort, ov, trt, paddle, paddle_trt.")
parser.add_argument(
"--enable_trt_fp16",
type=ast.literal_eval,
@@ -58,12 +70,17 @@ def parse_arguments():
"--enable_lite_fp16",
type=ast.literal_eval,
default=False,
help="whether enable fp16 in lite backend")
help="whether enable fp16 in Paddle Lite backend")
parser.add_argument(
"--enable_collect_memory_info",
type=ast.literal_eval,
default=False,
help="whether enable collect memory info")
parser.add_argument(
"--include_h2d_d2h",
type=ast.literal_eval,
default=False,
help="whether run profiling with h2d and d2h")
args = parser.parse_args()
return args
@@ -74,6 +91,8 @@ def build_option(args):
backend = args.backend
enable_trt_fp16 = args.enable_trt_fp16
enable_lite_fp16 = args.enable_lite_fp16
if args.profile_mode == "runtime":
option.enable_profiling(args.include_h2d_d2h, args.repeat, args.warmup)
option.set_cpu_thread_num(args.cpu_num_thread)
if device == "gpu":
option.use_gpu()
@@ -266,8 +285,12 @@ if __name__ == '__main__':
gpu_id = args.device_id
enable_collect_memory_info = args.enable_collect_memory_info
enable_record_time_of_backend = args.enable_record_time_of_backend
backend_repeat = args.backend_repeat
dump_result = dict()
end2end_statis = list()
prepost_statis = list()
h2d_d2h_statis = list()
cpu_mem = list()
gpu_mem = list()
gpu_util = list()
@@ -317,18 +340,26 @@ if __name__ == '__main__':
monitor = Monitor(enable_gpu, gpu_id)
monitor.start()
model.enable_record_time_of_runtime()
im_ori = cv2.imread(args.image)
for i in tqdm(range(args.iter_num)):
im = im_ori
if args.profile_mode == "runtime":
result = model.predict(im_ori)
profile_time = model.get_profile_time()
dump_result["runtime"] = profile_time * 1000
f.writelines("Runtime(ms): {} \n".format(str(dump_result["runtime"])))
print("Runtime(ms): {} \n".format(str(dump_result["runtime"])))
else:
# end2end
for i in range(args.warmup):
result = model.predict(im_ori)
start = time.time()
result = model.predict(im)
end2end_statis.append(time.time() - start)
for i in tqdm(range(args.repeat)):
result = model.predict(im_ori)
end = time.time()
dump_result["end2end"] = ((end - start) / args.repeat) * 1000.0
f.writelines("End2End(ms): {} \n".format(str(dump_result["end2end"])))
print("End2End(ms): {} \n".format(str(dump_result["end2end"])))
runtime_statis = model.print_statis_info_of_runtime()
warmup_iter = args.iter_num // 5
end2end_statis_repeat = end2end_statis[warmup_iter:]
if enable_collect_memory_info:
monitor.stop()
mem_info = monitor.output()
@@ -339,13 +370,6 @@ if __name__ == '__main__':
dump_result["gpu_util"] = mem_info['gpu'][
'utilization.gpu'] if 'gpu' in mem_info else 0
dump_result["runtime"] = runtime_statis["avg_time"] * 1000
dump_result["end2end"] = np.mean(end2end_statis_repeat) * 1000
f.writelines("Runtime(ms): {} \n".format(str(dump_result["runtime"])))
f.writelines("End2End(ms): {} \n".format(str(dump_result["end2end"])))
print("Runtime(ms): {} \n".format(str(dump_result["runtime"])))
print("End2End(ms): {} \n".format(str(dump_result["end2end"])))
if enable_collect_memory_info:
f.writelines("cpu_rss_mb: {} \n".format(
str(dump_result["cpu_rss_mb"])))
@@ -356,7 +380,8 @@ if __name__ == '__main__':
print("cpu_rss_mb: {} \n".format(str(dump_result["cpu_rss_mb"])))
print("gpu_rss_mb: {} \n".format(str(dump_result["gpu_rss_mb"])))
print("gpu_util: {} \n".format(str(dump_result["gpu_util"])))
except:
except Exception as e:
f.writelines("!!!!!Infer Failed\n")
raise e
f.close()


@@ -39,6 +39,7 @@ function(fastdeploy_summary)
message(STATUS " ENABLE_POROS_BACKEND : ${ENABLE_POROS_BACKEND}")
message(STATUS " ENABLE_TRT_BACKEND : ${ENABLE_TRT_BACKEND}")
message(STATUS " ENABLE_OPENVINO_BACKEND : ${ENABLE_OPENVINO_BACKEND}")
message(STATUS " ENABLE_BENCHMARK : ${ENABLE_BENCHMARK}")
message(STATUS " WITH_GPU : ${WITH_GPU}")
message(STATUS " WITH_ASCEND : ${WITH_ASCEND}")
message(STATUS " WITH_TIMVX : ${WITH_TIMVX}")


@@ -0,0 +1,86 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "fastdeploy/core/config.h"
#include "fastdeploy/utils/utils.h"
#include "fastdeploy/utils/perf.h"
#include "fastdeploy/benchmark/option.h"
#include "fastdeploy/benchmark/results.h"
#ifdef ENABLE_BENCHMARK
#define __RUNTIME_PROFILE_LOOP_BEGIN(option, base_loop) \
int __p_loop = (base_loop); \
const bool __p_enable_profile = option.enable_profile; \
const bool __p_include_h2d_d2h = option.include_h2d_d2h; \
const int __p_repeats = option.repeats; \
const int __p_warmup = option.warmup; \
if (__p_enable_profile && (!__p_include_h2d_d2h)) { \
__p_loop = (__p_repeats) + (__p_warmup); \
FDINFO << option << std::endl; \
} \
TimeCounter __p_tc; \
bool __p_tc_start = false; \
for (int __p_i = 0; __p_i < __p_loop; ++__p_i) { \
if (__p_i >= (__p_warmup) && (!__p_tc_start)) { \
__p_tc.Start(); \
__p_tc_start = true; \
} \
#define __RUNTIME_PROFILE_LOOP_END(result) \
} \
if ((__p_enable_profile && (!__p_include_h2d_d2h))) { \
if (__p_tc_start) { \
__p_tc.End(); \
double __p_tc_duration = __p_tc.Duration(); \
result.time_of_runtime = \
__p_tc_duration / static_cast<double>(__p_repeats); \
} \
}
#define __RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN(option, base_loop) \
int __p_loop_h = (base_loop); \
const bool __p_enable_profile_h = option.enable_profile; \
const bool __p_include_h2d_d2h_h = option.include_h2d_d2h; \
const int __p_repeats_h = option.repeats; \
const int __p_warmup_h = option.warmup; \
if (__p_enable_profile_h && __p_include_h2d_d2h_h) { \
__p_loop_h = (__p_repeats_h) + (__p_warmup_h); \
FDINFO << option << std::endl; \
} \
TimeCounter __p_tc_h; \
bool __p_tc_start_h = false; \
for (int __p_i_h = 0; __p_i_h < __p_loop_h; ++__p_i_h) { \
if (__p_i_h >= (__p_warmup_h) && (!__p_tc_start_h)) { \
__p_tc_h.Start(); \
__p_tc_start_h = true; \
} \
#define __RUNTIME_PROFILE_LOOP_H2D_D2H_END(result) \
} \
if ((__p_enable_profile_h && __p_include_h2d_d2h_h)) { \
if (__p_tc_start_h) { \
__p_tc_h.End(); \
double __p_tc_duration_h = __p_tc_h.Duration(); \
result.time_of_runtime = \
__p_tc_duration_h / static_cast<double>(__p_repeats_h); \
} \
}
#else
#define __RUNTIME_PROFILE_LOOP_BEGIN(option, base_loop) \
for (int __p_i = 0; __p_i < (base_loop); ++ __p_i) {
#define __RUNTIME_PROFILE_LOOP_END(result) }
#define __RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN(option, base_loop) \
for (int __p_i_h = 0; __p_i_h < (base_loop); ++ __p_i_h) {
#define __RUNTIME_PROFILE_LOOP_H2D_D2H_END(result) }
#endif


@@ -0,0 +1,47 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace fastdeploy {
/** \brief All C++ FastDeploy benchmark profile APIs are defined inside this namespace
*
*/
namespace benchmark {
/*! @brief Option object used to control the behavior of the benchmark profiling.
*/
struct BenchmarkOption {
int warmup = 50; ///< Warmup for backend inference.
int repeats = 100; ///< Repeats for backend inference.
bool enable_profile = false; ///< Whether to use profile or not.
bool include_h2d_d2h = false; ///< Whether to include time of H2D_D2H for time of runtime.
friend std::ostream& operator<<(
std::ostream& output, const BenchmarkOption &option) {
if (!option.include_h2d_d2h) {
output << "Running profiling for Runtime "
<< "without H2D and D2H, ";
} else {
output << "Running profiling for Runtime "
<< "with H2D and D2H, ";
}
output << "Repeats: " << option.repeats << ", "
<< "Warmup: " << option.warmup;
return output;
}
};
} // namespace benchmark
} // namespace fastdeploy


@@ -0,0 +1,27 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace fastdeploy {
namespace benchmark {
/*! @brief Result object used to record the time of runtime after benchmark profiling is done.
*/
struct BenchmarkResult {
///< Means pure_backend_time+time_of_h2d_d2h(if include_h2d_d2h=true).
double time_of_runtime = 0.0f;
};
} // namespace benchmark
} // namespace fastdeploy
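
BenchmarkOption carries the repeat/warmup settings into the backend, and BenchmarkResult.time_of_runtime receives the averaged time per repeat in seconds. A hedged sketch of the flow at the raw Runtime level, assuming the existing set_model_path / Runtime.infer / get_input_info Python APIs and placeholder model files:

import numpy as np
import fastdeploy as fd

option = fd.RuntimeOption()
option.set_model_path("model.pdmodel", "model.pdiparams")   # placeholder files
# repeats=100, warmup=50 mirror the BenchmarkOption defaults; H2D/D2H excluded
option.enable_profiling(False, 100, 50)

runtime = fd.Runtime(option)
dummy = np.random.rand(1, 3, 224, 224).astype("float32")    # shape is illustrative
outputs = runtime.infer({runtime.get_input_info(0).name: dummy})

# get_profile_time() exposes BenchmarkResult.time_of_runtime: mean seconds per repeat
print("Runtime(ms):", runtime.get_profile_time() * 1000)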


@@ -56,3 +56,7 @@
#ifndef ENABLE_TEXT
#cmakedefine ENABLE_TEXT
#endif
#ifndef ENABLE_BENCHMARK
#cmakedefine ENABLE_BENCHMARK
#endif


@@ -31,7 +31,8 @@ std::string Str(const std::vector<Backend>& backends) {
return oss.str();
}
bool IsSupported(const std::vector<Backend>& backends, Backend backend) {
bool CheckBackendSupported(const std::vector<Backend>& backends,
Backend backend) {
for (size_t i = 0; i < backends.size(); ++i) {
if (backends[i] == backend) {
return true;
@@ -40,6 +41,22 @@ bool IsSupported(const std::vector<Backend>& backends, Backend backend) {
return false;
}
bool FastDeployModel::IsSupported(const std::vector<Backend>& backends,
Backend backend) {
#ifdef ENABLE_BENCHMARK
if (runtime_option.benchmark_option.enable_profile) {
FDWARNING << "In benchmark mode, we don't check to see if "
<< "the backend [" << backend
<< "] is supported for current model!"
<< std::endl;
return true;
}
return CheckBackendSupported(backends, backend);
#else
return CheckBackendSupported(backends, backend);
#endif
}
bool FastDeployModel::InitRuntimeWithSpecifiedBackend() {
if (!IsBackendAvailable(runtime_option.backend)) {
FDERROR << runtime_option.backend
@@ -373,6 +390,7 @@ bool FastDeployModel::Infer(std::vector<FDTensor>& input_tensors,
}
time_of_runtime_.push_back(tc.Duration());
}
return ret;
}
@@ -416,6 +434,7 @@ std::map<std::string, float> FastDeployModel::PrintStatisInfoOfRuntime() {
statis_info_of_runtime_dict["warmup_iter"] = warmup_iter;
statis_info_of_runtime_dict["avg_time"] = avg_time;
statis_info_of_runtime_dict["iterations"] = time_of_runtime_.size();
return statis_info_of_runtime_dict;
}
} // namespace fastdeploy


@@ -75,7 +75,7 @@ class FASTDEPLOY_DECL FastDeployModel {
return runtime_initialized_ && initialized;
}
/** \brief This is a debug interface, used to record the time of backend runtime
/** \brief This is a debug interface, used to record the time of runtime (backend + h2d + d2h)
*
* example code @code
* auto model = fastdeploy::vision::PPYOLOE("model.pdmodel", "model.pdiparams", "infer_cfg.yml");
@@ -98,7 +98,7 @@ class FASTDEPLOY_DECL FastDeployModel {
enable_record_time_of_runtime_ = true;
}
/** \brief Disable to record the time of backend runtime, see `EnableRecordTimeOfRuntime()` for more detail
/** \brief Disable to record the time of runtime, see `EnableRecordTimeOfRuntime()` for more detail
*/
virtual void DisableRecordTimeOfRuntime() {
enable_record_time_of_runtime_ = false;
@@ -113,6 +113,11 @@ class FASTDEPLOY_DECL FastDeployModel {
virtual bool EnabledRecordTimeOfRuntime() {
return enable_record_time_of_runtime_;
}
/** \brief Get profile time of Runtime after the profile process is done.
*/
virtual double GetProfileTime() {
return runtime_->GetProfileTime();
}
/** \brief Release reused input/output buffers
*/
@@ -153,13 +158,13 @@ class FASTDEPLOY_DECL FastDeployModel {
bool CreateTimVXBackend();
bool CreateKunlunXinBackend();
bool CreateASCENDBackend();
bool IsSupported(const std::vector<Backend>& backends,
Backend backend);
std::shared_ptr<Runtime> runtime_;
bool runtime_initialized_ = false;
// whether to record inference time
bool enable_record_time_of_runtime_ = false;
// record inference time for backend
std::vector<double> time_of_runtime_;
};
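
The legacy debug interface (enable_record_time_of_runtime / print_statis_info_of_runtime) is kept alongside the new profiler: it records wall-clock time per predict() including H2D/D2H, whereas get_profile_time() returns the time measured by the backend's profiling loop. A hedged sketch of the legacy path at the Python model level, with placeholder paths (the two mechanisms are intended to be used separately):

import cv2
import fastdeploy as fd

model = fd.vision.detection.PPYOLOE(
    "model.pdmodel", "model.pdiparams", "infer_cfg.yml")   # placeholder paths

model.enable_record_time_of_runtime()         # start recording per-predict runtime
im = cv2.imread("test.jpg")
for _ in range(100):                          # collect enough samples for the statistics
    model.predict(im)
stats = model.print_statis_info_of_runtime()  # {'avg_time', 'warmup_iter', 'iterations'}
print("avg runtime(ms):", stats["avg_time"] * 1000)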


@@ -30,6 +30,8 @@ void BindFDModel(pybind11::module& m) {
&FastDeployModel::DisableRecordTimeOfRuntime)
.def("print_statis_info_of_runtime",
&FastDeployModel::PrintStatisInfoOfRuntime)
.def("get_profile_time",
&FastDeployModel::GetProfileTime)
.def("initialized", &FastDeployModel::Initialized)
.def_readwrite("runtime_option", &FastDeployModel::runtime_option)
.def_readwrite("valid_cpu_backends", &FastDeployModel::valid_cpu_backends)


@@ -77,6 +77,8 @@ void BindRuntime(pybind11::module& m) {
.def("set_ipu_config", &RuntimeOption::SetIpuConfig)
.def("delete_paddle_backend_pass",
&RuntimeOption::DeletePaddleBackendPass)
.def("enable_profiling", &RuntimeOption::EnableProfiling)
.def("disable_profiling", &RuntimeOption::DisableProfiling)
.def("disable_paddle_trt_ops", &RuntimeOption::DisablePaddleTrtOPs)
.def_readwrite("model_file", &RuntimeOption::model_file)
.def_readwrite("params_file", &RuntimeOption::params_file)
@@ -217,6 +219,7 @@ void BindRuntime(pybind11::module& m) {
.def("num_outputs", &Runtime::NumOutputs)
.def("get_input_info", &Runtime::GetInputInfo)
.def("get_output_info", &Runtime::GetOutputInfo)
.def("get_profile_time", &Runtime::GetProfileTime)
.def_readonly("option", &Runtime::option);
pybind11::enum_<Backend>(m, "Backend", pybind11::arithmetic(),


@@ -22,6 +22,7 @@
#include "fastdeploy/core/fd_tensor.h"
#include "fastdeploy/core/fd_type.h"
#include "fastdeploy/runtime/runtime_option.h"
#include "fastdeploy/benchmark/benchmark.h"
namespace fastdeploy {
@@ -79,7 +80,6 @@ class BaseBackend {
virtual bool Infer(std::vector<FDTensor>& inputs,
std::vector<FDTensor>* outputs,
bool copy_to_fd = true) = 0;
// Optional: For those backends which can share memory
// while creating multiple inference engines with same model file
virtual std::unique_ptr<BaseBackend> Clone(RuntimeOption &runtime_option,
@@ -88,6 +88,70 @@ class BaseBackend {
FDERROR << "Clone no support" << std::endl;
return nullptr;
}
benchmark::BenchmarkOption benchmark_option_;
benchmark::BenchmarkResult benchmark_result_;
};
/** \brief Macros for Runtime benchmark profiling.
* The param 'base_loop' for 'RUNTIME_PROFILE_LOOP_BEGIN'
* indicates the minimum number of times the loop will
* run when profiling mode is not enabled.
* In most cases the value should be 1, i.e., when profiling
* is turned off the results are obtained by running the
* inference process once, as for ONNX Runtime, OpenVINO,
* TensorRT, Paddle Inference, Paddle Lite, RKNPU2, SOPHGO, etc.
*
* example code @code
* // OpenVINOBackend::Infer
* RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN
* // do something ....
* RUNTIME_PROFILE_LOOP_BEGIN(1)
* // The code wrapped by the 'BEGIN(1) ~ END' scope
* // runs only once when profiling mode is not enabled.
* request_.infer();
* RUNTIME_PROFILE_LOOP_END
* // do something ....
* RUNTIME_PROFILE_LOOP_H2D_D2H_END
*
* @endcode In this case, no variables that are needed by
* subsequent code are created inside the BEGIN ~ END scope.
* But sometimes we need to set 'base_loop' to 0, as for POROS.
*
* example code @code
* // PorosBackend::Infer
* RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN
* // do something ....
* RUNTIME_PROFILE_LOOP_BEGIN(0) // set 'base_loop' as 0
* // The code wrapped by the 'BEGIN(0) ~ END' scope
* // does not run when profiling mode is not enabled.
* auto poros_outputs = _poros_module->forward(poros_inputs);
* RUNTIME_PROFILE_LOOP_END
* // Run another inference beyond the scope of 'BEGIN ~ END'
* // to get valid outputs for subsequent tasks.
* auto poros_outputs = _poros_module->forward(poros_inputs);
* // do something .... will use 'poros_outputs' ...
* if (poros_outputs.isTensor()) {
* // ...
* }
* RUNTIME_PROFILE_LOOP_H2D_D2H_END
*
* @endcode In this case, 'poros_outputs' is created inside the
* BEGIN ~ END scope but is required by subsequent code. So we
* set 'base_loop' to 0 and launch another inference beyond the
* scope of 'BEGIN ~ END' to get valid outputs for the
* subsequent tasks.
*/
#define RUNTIME_PROFILE_LOOP_BEGIN(base_loop) \
__RUNTIME_PROFILE_LOOP_BEGIN(benchmark_option_, (base_loop))
#define RUNTIME_PROFILE_LOOP_END \
__RUNTIME_PROFILE_LOOP_END(benchmark_result_)
#define RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN \
__RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN(benchmark_option_, 1)
#define RUNTIME_PROFILE_LOOP_H2D_D2H_END \
__RUNTIME_PROFILE_LOOP_H2D_D2H_END(benchmark_result_)
} // namespace fastdeploy


@@ -13,23 +13,6 @@
// limitations under the License.
#include "fastdeploy/runtime/backends/lite/lite_backend.h"
// https://github.com/PaddlePaddle/Paddle-Lite/issues/8290
// When compiling the FastDeploy dynamic library, namely,
// WITH_STATIC_LIB=OFF, and depending on the Paddle Lite
// static library, you need to include the fake registration
// codes of Paddle Lite. When you compile the FastDeploy static
// library and depends on the Paddle Lite static library,
// WITH_STATIC_LIB=ON, you do not need to include the fake
// registration codes for Paddle Lite, but wait until you
// use the FastDeploy static library.
#if (defined(WITH_LITE_STATIC) && (!defined(WITH_STATIC_LIB)))
#warning You are compiling the FastDeploy dynamic library with \
Paddle Lite static lib We will automatically add some registration \
codes for ops, kernels and passes for Paddle Lite.
#include "paddle_use_kernels.h" // NOLINT
#include "paddle_use_ops.h" // NOLINT
#include "paddle_use_passes.h" // NOLINT
#endif
#include <cstring>
@@ -156,4 +139,5 @@ void LiteBackend::ConfigureNNAdapter(const LiteBackendOption& option) {
config_.set_nnadapter_dynamic_shape_info(option.nnadapter_dynamic_shape_info);
}
} // namespace fastdeploy


@@ -100,7 +100,7 @@ bool LiteBackend::InitFromPaddle(const std::string& model_file,
auto shape = tensor->shape();
info.shape.assign(shape.begin(), shape.end());
info.name = output_names[i];
if (!option_.device == Device::KUNLUNXIN) {
if (option_.device != Device::KUNLUNXIN) {
info.dtype = LiteDataTypeToFD(tensor->precision());
}
outputs_desc_.emplace_back(info);
@@ -136,6 +136,8 @@ bool LiteBackend::Infer(std::vector<FDTensor>& inputs,
<< inputs_desc_.size() << ")." << std::endl;
return false;
}
RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN
for (size_t i = 0; i < inputs.size(); ++i) {
auto iter = inputs_order_.find(inputs[i].name);
if (iter == inputs_order_.end()) {
@@ -143,6 +145,7 @@ bool LiteBackend::Infer(std::vector<FDTensor>& inputs,
<< " in loaded model." << std::endl;
return false;
}
auto tensor = predictor_->GetInput(iter->second);
// Adjust dims only, allocate lazy.
tensor->Resize(inputs[i].shape);
@@ -175,7 +178,9 @@ bool LiteBackend::Infer(std::vector<FDTensor>& inputs,
}
}
RUNTIME_PROFILE_LOOP_BEGIN(1)
predictor_->Run();
RUNTIME_PROFILE_LOOP_END
outputs->resize(outputs_desc_.size());
for (size_t i = 0; i < outputs_desc_.size(); ++i) {
@@ -188,6 +193,7 @@ bool LiteBackend::Infer(std::vector<FDTensor>& inputs,
memcpy((*outputs)[i].MutableData(), tensor->data<void>(),
(*outputs)[i].Nbytes());
}
RUNTIME_PROFILE_LOOP_H2D_D2H_END
return true;
}

fastdeploy/runtime/backends/lite/lite_backend.h — mode changed from Executable file to Normal file (0 line changes)


@@ -375,6 +375,7 @@ bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs,
return false;
}
RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN
for (size_t i = 0; i < inputs.size(); ++i) {
ov::Shape shape(inputs[i].shape.begin(), inputs[i].shape.end());
ov::Tensor ov_tensor(FDDataTypeToOV(inputs[i].dtype), shape,
@@ -382,7 +383,9 @@ bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs,
request_.set_tensor(inputs[i].name, ov_tensor);
}
RUNTIME_PROFILE_LOOP_BEGIN(1)
request_.infer();
RUNTIME_PROFILE_LOOP_END
outputs->resize(output_infos_.size());
for (size_t i = 0; i < output_infos_.size(); ++i) {
@@ -403,6 +406,7 @@ bool OpenVINOBackend::Infer(std::vector<FDTensor>& inputs,
out_tensor.data(), Device::CPU);
}
}
RUNTIME_PROFILE_LOOP_H2D_D2H_END
return true;
}


@@ -13,9 +13,6 @@
// limitations under the License.
#include "fastdeploy/runtime/backends/ort/ort_backend.h"
#include <memory>
#include "fastdeploy/core/float16.h"
#include "fastdeploy/runtime/backends/ort/ops/adaptive_pool2d.h"
#include "fastdeploy/runtime/backends/ort/ops/multiclass_nms.h"
@@ -25,6 +22,9 @@
#include "paddle2onnx/converter.h"
#endif
#include <memory>
namespace fastdeploy {
std::vector<OrtCustomOp*> OrtBackend::custom_operators_ =
@@ -258,6 +258,7 @@ bool OrtBackend::Infer(std::vector<FDTensor>& inputs,
}
// from FDTensor to Ort Inputs
RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN
for (size_t i = 0; i < inputs.size(); ++i) {
auto ort_value = CreateOrtValue(inputs[i], option_.device == Device::GPU);
binding_->BindInput(inputs[i].name.c_str(), ort_value);
@@ -270,12 +271,14 @@ bool OrtBackend::Infer(std::vector<FDTensor>& inputs,
}
// Inference with inputs
RUNTIME_PROFILE_LOOP_BEGIN(1)
try {
session_.Run({}, *(binding_.get()));
} catch (const std::exception& e) {
FDERROR << "Failed to Infer: " << e.what() << std::endl;
return false;
}
RUNTIME_PROFILE_LOOP_END
// Convert result after inference
std::vector<Ort::Value> ort_outputs = binding_->GetOutputValues();
@@ -284,7 +287,7 @@ bool OrtBackend::Infer(std::vector<FDTensor>& inputs,
OrtValueToFDTensor(ort_outputs[i], &((*outputs)[i]), outputs_desc_[i].name,
copy_to_fd);
}
RUNTIME_PROFILE_LOOP_H2D_D2H_END
return true;
}


@@ -222,12 +222,15 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
return false;
}
RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN
for (size_t i = 0; i < inputs.size(); ++i) {
auto handle = predictor_->GetInputHandle(inputs[i].name);
ShareTensorFromFDTensor(handle.get(), inputs[i]);
}
RUNTIME_PROFILE_LOOP_BEGIN(1)
predictor_->Run();
RUNTIME_PROFILE_LOOP_END
// output share backend memory only support CPU or GPU
if (option_.use_ipu) {
@@ -241,6 +244,7 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
}
PaddleTensorToFDTensor(handle, &((*outputs)[i]), copy_to_fd);
}
RUNTIME_PROFILE_LOOP_H2D_D2H_END
return true;
}


@@ -287,14 +287,18 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
BuildTrtEngine();
}
RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN
cudaSetDevice(option_.gpu_id);
SetInputs(inputs);
AllocateOutputsBuffer(outputs, copy_to_fd);
RUNTIME_PROFILE_LOOP_BEGIN(1)
if (!context_->enqueueV2(bindings_.data(), stream_, nullptr)) {
FDERROR << "Failed to Infer with TensorRT." << std::endl;
return false;
}
RUNTIME_PROFILE_LOOP_END
for (size_t i = 0; i < outputs->size(); ++i) {
// if the final output tensor's dtype is different from the model output
// tensor's dtype, then we need cast the data to the final output's dtype
@@ -335,7 +339,7 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
FDASSERT(cudaStreamSynchronize(stream_) == cudaSuccess,
"[ERROR] Error occurs while sync cuda stream.");
}
RUNTIME_PROFILE_LOOP_H2D_D2H_END
return true;
}


@@ -275,6 +275,8 @@ void Runtime::CreatePaddleBackend() {
#endif
backend_ = utils::make_unique<PaddleBackend>();
auto casted_backend = dynamic_cast<PaddleBackend*>(backend_.get());
casted_backend->benchmark_option_ = option.benchmark_option;
if (pd_option.model_from_memory_) {
FDASSERT(casted_backend->InitFromPaddle(option.model_file,
option.params_file, pd_option),
@@ -303,6 +305,7 @@ void Runtime::CreatePaddleBackend() {
void Runtime::CreateOpenVINOBackend() {
#ifdef ENABLE_OPENVINO_BACKEND
backend_ = utils::make_unique<OpenVINOBackend>();
backend_->benchmark_option_ = option.benchmark_option;
FDASSERT(backend_->Init(option), "Failed to initialize OpenVINOBackend.");
#else
FDASSERT(false,
@@ -316,6 +319,8 @@ void Runtime::CreateOpenVINOBackend() {
void Runtime::CreateOrtBackend() {
#ifdef ENABLE_ORT_BACKEND
backend_ = utils::make_unique<OrtBackend>();
backend_->benchmark_option_ = option.benchmark_option;
FDASSERT(backend_->Init(option), "Failed to initialize Backend::ORT.");
#else
FDASSERT(false,
@@ -351,6 +356,8 @@ void Runtime::CreateTrtBackend() {
trt_option.external_stream_ = option.external_stream_;
backend_ = utils::make_unique<TrtBackend>();
auto casted_backend = dynamic_cast<TrtBackend*>(backend_.get());
casted_backend->benchmark_option_ = option.benchmark_option;
if (option.model_format == ModelFormat::ONNX) {
if (option.model_from_memory_) {
FDASSERT(casted_backend->InitFromOnnx(option.model_file, trt_option),
@@ -403,6 +410,8 @@ void Runtime::CreateLiteBackend() {
"LiteBackend only support model format of ModelFormat::PADDLE");
backend_ = utils::make_unique<LiteBackend>();
auto casted_backend = dynamic_cast<LiteBackend*>(backend_.get());
casted_backend->benchmark_option_ = option.benchmark_option;
FDASSERT(casted_backend->InitFromPaddle(option.model_file, option.params_file,
option.paddle_lite_option),
"Load model from nb file failed while initializing LiteBackend.");


@@ -95,6 +95,11 @@ struct FASTDEPLOY_DECL Runtime {
*/
bool Compile(std::vector<std::vector<FDTensor>>& prewarm_tensors,
const RuntimeOption& _option);
/** \brief Get profile time of Runtime after the profile process is done.
*/
double GetProfileTime() {
return backend_->benchmark_result_.time_of_runtime;
}
private:
void CreateOrtBackend();


@@ -32,6 +32,7 @@
#include "fastdeploy/runtime/backends/rknpu2/option.h"
#include "fastdeploy/runtime/backends/sophgo/option.h"
#include "fastdeploy/runtime/backends/tensorrt/option.h"
#include "fastdeploy/benchmark/option.h"
namespace fastdeploy {
@@ -347,6 +348,26 @@ struct FASTDEPLOY_DECL RuntimeOption {
float available_memory_proportion = 1.0,
bool enable_half_partial = false);
/** \brief Set the profile mode as 'true'.
*
* \param[in] inclue_h2d_d2h Whether to include time of H2D_D2H for time of runtime.
* \param[in] repeat Repeat times for runtime inference.
* \param[in] warmup Warmup times for runtime inference.
*/
void EnableProfiling(bool inclue_h2d_d2h = false,
int repeat = 100, int warmup = 50) {
benchmark_option.enable_profile = true;
benchmark_option.warmup = warmup;
benchmark_option.repeats = repeat;
benchmark_option.include_h2d_d2h = inclue_h2d_d2h;
}
/** \brief Set the profile mode as 'false'.
*/
void DisableProfiling() {
benchmark_option.enable_profile = false;
}
Backend backend = Backend::UNKNOWN;
// for cpu inference
@@ -419,6 +440,9 @@ struct FASTDEPLOY_DECL RuntimeOption {
bool model_from_memory_ = false;
// format of input model
ModelFormat model_format = ModelFormat::PADDLE;
// Benchmark option
benchmark::BenchmarkOption benchmark_option;
};
} // namespace fastdeploy


@@ -54,6 +54,11 @@ class FastDeployModel:
def print_statis_info_of_runtime(self):
return self._model.print_statis_info_of_runtime()
def get_profile_time(self):
"""Get profile time of Runtime after the profile process is done.
"""
return self._model.get_profile_time()
@property
def runtime_option(self):
return self._model.runtime_option if self._model is not None else None


@@ -144,6 +144,11 @@ class Runtime:
index, self.num_outputs)
return self._runtime.get_output_info(index)
def get_profile_time(self):
"""Get profile time of Runtime after the profile process is done.
"""
return self._runtime.get_profile_time()
class RuntimeOption:
"""Options for FastDeploy Runtime.
@@ -552,6 +557,21 @@ class RuntimeOption:
available_memory_proportion,
enable_half_partial)
def enable_profiling(self,
inclue_h2d_d2h=False,
repeat=100, warmup=50):
"""Set the profile mode as 'true'.
:param inclue_h2d_d2h Whether to include time of H2D_D2H for time of runtime.
:param repeat Repeat times for runtime inference.
:param warmup Warmup times for runtime inference.
"""
return self._option.enable_profiling(inclue_h2d_d2h, repeat, warmup)
def disable_profiling(self):
"""Set the profile mode as 'false'.
"""
return self._option.disable_profiling()
def __repr__(self):
attrs = dir(self._option)
message = "RuntimeOption(\n"


@@ -73,6 +73,7 @@ setup_configs["ENABLE_VISION"] = os.getenv("ENABLE_VISION", "OFF")
setup_configs["ENABLE_ENCRYPTION"] = os.getenv("ENABLE_ENCRYPTION", "OFF")
setup_configs["ENABLE_FLYCV"] = os.getenv("ENABLE_FLYCV", "OFF")
setup_configs["ENABLE_TEXT"] = os.getenv("ENABLE_TEXT", "OFF")
setup_configs["ENABLE_BENCHMARK"] = os.getenv("ENABLE_BENCHMARK", "OFF")
setup_configs["WITH_GPU"] = os.getenv("WITH_GPU", "OFF")
setup_configs["WITH_IPU"] = os.getenv("WITH_IPU", "OFF")
setup_configs["WITH_KUNLUNXIN"] = os.getenv("WITH_KUNLUNXIN", "OFF")


@@ -0,0 +1,80 @@
#!/bin/bash
set -e
set +x
# -------------------------------------------------------------------------------
# readonly global variables
# -------------------------------------------------------------------------------
readonly ROOT_PATH=$(pwd)
readonly BUILD_ROOT=build/Linux
readonly BUILD_DIR=${BUILD_ROOT}/x86_64
# -------------------------------------------------------------------------------
# tasks
# -------------------------------------------------------------------------------
__make_build_dir() {
if [ ! -d "${BUILD_DIR}" ]; then
echo "-- [INFO] BUILD_DIR: ${BUILD_DIR} not exists, setup manually ..."
if [ ! -d "${BUILD_ROOT}" ]; then
mkdir -p "${BUILD_ROOT}" && echo "-- [INFO] Created ${BUILD_ROOT} !"
fi
mkdir -p "${BUILD_DIR}" && echo "-- [INFO] Created ${BUILD_DIR} !"
else
echo "-- [INFO] Found BUILD_DIR: ${BUILD_DIR}"
fi
}
__check_cxx_envs() {
if [ $LDFLAGS ]; then
echo "-- [INFO] Found LDFLAGS: ${LDFLAGS}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset LDFLAGS
fi
if [ $CPPFLAGS ]; then
echo "-- [INFO] Found CPPFLAGS: ${CPPFLAGS}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset CPPFLAGS
fi
if [ $CPLUS_INCLUDE_PATH ]; then
echo "-- [INFO] Found CPLUS_INCLUDE_PATH: ${CPLUS_INCLUDE_PATH}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset CPLUS_INCLUDE_PATH
fi
if [ $C_INCLUDE_PATH ]; then
echo "-- [INFO] Found C_INCLUDE_PATH: ${C_INCLUDE_PATH}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset C_INCLUDE_PATH
fi
}
__build_fastdeploy_linux_x86_64_shared() {
local FASDEPLOY_INSTALL_DIR="${ROOT_PATH}/${BUILD_DIR}/install"
cd "${BUILD_DIR}" && echo "-- [INFO] Working Dir: ${PWD}"
cmake -DCMAKE_BUILD_TYPE=Release \
-DWITH_GPU=OFF \
-DENABLE_ORT_BACKEND=ON \
-DENABLE_PADDLE_BACKEND=ON \
-DENABLE_OPENVINO_BACKEND=ON \
-DENABLE_PADDLE2ONNX=ON \
-DENABLE_VISION=ON \
-DENABLE_BENCHMARK=ON \
-DBUILD_EXAMPLES=ON \
-DCMAKE_INSTALL_PREFIX=${FASDEPLOY_INSTALL_DIR} \
-Wno-dev ../../.. && make -j8 && make install
echo "-- [INFO][built][x86_64]][${BUILD_DIR}/install]"
}
main() {
__make_build_dir
__check_cxx_envs
__build_fastdeploy_linux_x86_64_shared
exit 0
}
main
# Usage:
# ./scripts/linux/build_linux_x86_64_cpp_cpu.sh


@@ -0,0 +1,83 @@
#!/bin/bash
set -e
set +x
# -------------------------------------------------------------------------------
# readonly global variables
# -------------------------------------------------------------------------------
readonly ROOT_PATH=$(pwd)
readonly BUILD_ROOT=build/Linux
readonly BUILD_DIR="${BUILD_ROOT}/x86_64_gpu"
# -------------------------------------------------------------------------------
# tasks
# -------------------------------------------------------------------------------
__make_build_dir() {
if [ ! -d "${BUILD_DIR}" ]; then
echo "-- [INFO] BUILD_DIR: ${BUILD_DIR} not exists, setup manually ..."
if [ ! -d "${BUILD_ROOT}" ]; then
mkdir -p "${BUILD_ROOT}" && echo "-- [INFO] Created ${BUILD_ROOT} !"
fi
mkdir -p "${BUILD_DIR}" && echo "-- [INFO] Created ${BUILD_DIR} !"
else
echo "-- [INFO] Found BUILD_DIR: ${BUILD_DIR}"
fi
}
__check_cxx_envs() {
if [ $LDFLAGS ]; then
echo "-- [INFO] Found LDFLAGS: ${LDFLAGS}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset LDFLAGS
fi
if [ $CPPFLAGS ]; then
echo "-- [INFO] Found CPPFLAGS: ${CPPFLAGS}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset CPPFLAGS
fi
if [ $CPLUS_INCLUDE_PATH ]; then
echo "-- [INFO] Found CPLUS_INCLUDE_PATH: ${CPLUS_INCLUDE_PATH}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset CPLUS_INCLUDE_PATH
fi
if [ $C_INCLUDE_PATH ]; then
echo "-- [INFO] Found C_INCLUDE_PATH: ${C_INCLUDE_PATH}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset C_INCLUDE_PATH
fi
}
__build_fastdeploy_linux_x86_64_gpu_shared() {
local FASDEPLOY_INSTALL_DIR="${ROOT_PATH}/${BUILD_DIR}/install"
cd "${BUILD_DIR}" && echo "-- [INFO] Working Dir: ${PWD}"
cmake -DCMAKE_BUILD_TYPE=Release \
-DWITH_GPU=ON \
-DTRT_DIRECTORY=${TRT_DIRECTORY} \
-DCUDA_DIRECTORY=${CUDA_DIRECTORY} \
-DENABLE_ORT_BACKEND=ON \
-DENABLE_TRT_BACKEND=ON \
-DENABLE_PADDLE_BACKEND=ON \
-DENABLE_OPENVINO_BACKEND=ON \
-DENABLE_PADDLE2ONNX=ON \
-DENABLE_VISION=ON \
-DENABLE_BENCHMARK=ON \
-DBUILD_EXAMPLES=ON \
-DCMAKE_INSTALL_PREFIX=${FASDEPLOY_INSTALL_DIR} \
-Wno-dev ../../.. && make -j8 && make install
echo "-- [INFO][built][x86_64_gpu}][${BUILD_DIR}/install]"
}
main() {
__make_build_dir
__check_cxx_envs
__build_fastdeploy_linux_x86_64_gpu_shared
exit 0
}
main
# Usage:
# ./scripts/linux/build_linux_x86_64_cpp_gpu.sh


@@ -0,0 +1,102 @@
#!/bin/bash
set -e
set +x
# -------------------------------------------------------------------------------
# readonly global variables
# -------------------------------------------------------------------------------
readonly ROOT_PATH=$(pwd)
readonly BUILD_ROOT=build/MacOSX
readonly OSX_ARCH=$1 # arm64, x86_64
readonly BUILD_DIR=${BUILD_ROOT}/${OSX_ARCH}
# -------------------------------------------------------------------------------
# tasks
# -------------------------------------------------------------------------------
__make_build_dir() {
if [ ! -d "${BUILD_DIR}" ]; then
echo "-- [INFO] BUILD_DIR: ${BUILD_DIR} not exists, setup manually ..."
if [ ! -d "${BUILD_ROOT}" ]; then
mkdir -p "${BUILD_ROOT}" && echo "-- [INFO] Created ${BUILD_ROOT} !"
fi
mkdir -p "${BUILD_DIR}" && echo "-- [INFO] Created ${BUILD_DIR} !"
else
echo "-- [INFO] Found BUILD_DIR: ${BUILD_DIR}"
fi
}
__check_cxx_envs() {
if [ $LDFLAGS ]; then
echo "-- [INFO] Found LDFLAGS: ${LDFLAGS}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset LDFLAGS
fi
if [ $CPPFLAGS ]; then
echo "-- [INFO] Found CPPFLAGS: ${CPPFLAGS}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset CPPFLAGS
fi
if [ $CPLUS_INCLUDE_PATH ]; then
echo "-- [INFO] Found CPLUS_INCLUDE_PATH: ${CPLUS_INCLUDE_PATH}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset CPLUS_INCLUDE_PATH
fi
if [ $C_INCLUDE_PATH ]; then
echo "-- [INFO] Found C_INCLUDE_PATH: ${C_INCLUDE_PATH}, \c"
echo "unset it before crossing compiling ${BUILD_DIR}"
unset C_INCLUDE_PATH
fi
}
__build_fastdeploy_osx_arm64_shared() {
local FASDEPLOY_INSTALL_DIR="${ROOT_PATH}/${BUILD_DIR}/install"
cd "${BUILD_DIR}" && echo "-- [INFO] Working Dir: ${PWD}"
cmake -DCMAKE_BUILD_TYPE=MinSizeRel \
-DENABLE_ORT_BACKEND=ON \
-DENABLE_PADDLE2ONNX=ON \
-DENABLE_VISION=ON \
-DENABLE_BENCHMARK=ON \
-DBUILD_EXAMPLES=ON \
-DCMAKE_INSTALL_PREFIX=${FASDEPLOY_INSTALL_DIR} \
-Wno-dev ../../.. && make -j8 && make install
echo "-- [INFO][built][${OSX_ARCH}][${BUILD_DIR}/install]"
}
__build_fastdeploy_osx_x86_64_shared() {
local FASDEPLOY_INSTALL_DIR="${ROOT_PATH}/${BUILD_DIR}/install"
cd "${BUILD_DIR}" && echo "-- [INFO] Working Dir: ${PWD}"
cmake -DCMAKE_BUILD_TYPE=MinSizeRel \
-DENABLE_ORT_BACKEND=ON \
-DENABLE_PADDLE_BACKEND=ON \
-DENABLE_OPENVINO_BACKEND=ON \
-DENABLE_PADDLE2ONNX=ON \
-DENABLE_VISION=ON \
-DENABLE_BENCHMARK=ON \
-DBUILD_EXAMPLES=ON \
-DCMAKE_INSTALL_PREFIX=${FASDEPLOY_INSTALL_DIR} \
-Wno-dev ../../.. && make -j8 && make install
echo "-- [INFO][built][${OSX_ARCH}][${BUILD_DIR}/install]"
}
main() {
__make_build_dir
__check_cxx_envs
if [ "$OSX_ARCH" = "arm64" ]; then
__build_fastdeploy_osx_arm64_shared
else
__build_fastdeploy_osx_x86_64_shared
fi
exit 0
}
main
# Usage:
# ./scripts/macosx/build_macosx_cpp.sh arm64
# ./scripts/macosx/build_macosx_cpp.sh x86_64