[Benchmark] Add C++ benchmark for YOLOv5 (#1224)

* add GPL license

* add GPL-3.0 license

* add GPL-3.0 license

* add GPL-3.0 license

* support yolov8

* add pybind for yolov8

* add yolov8 readme

* add cpp benchmark

* add cpu and gpu mem

* public part split

* add runtime mode

* fixed bugs

* add cpu_thread_nums

* deal with comments

* deal with comments

* deal with comments

* rm useless code

* add FASTDEPLOY_DECL

* add FASTDEPLOY_DECL
Authored by WJJ1995 on 2023-02-07 21:26:04 +08:00, committed by GitHub
parent e90e1ff435
commit c487359e33
27 changed files with 422 additions and 44 deletions

benchmark/cpp/CMakeLists.txt (new executable file, 17 lines)

@@ -0,0 +1,17 @@
PROJECT(infer_demo C CXX)
CMAKE_MINIMUM_REQUIRED (VERSION 3.10)

# specify the decompress directory of FastDeploy SDK
option(FASTDEPLOY_INSTALL_DIR "Path of downloaded fastdeploy sdk.")

include(${FASTDEPLOY_INSTALL_DIR}/utils/gflags.cmake)
include(${FASTDEPLOY_INSTALL_DIR}/FastDeploy.cmake)
include_directories(${FASTDEPLOY_INCS})

add_executable(benchmark_yolov5 ${PROJECT_SOURCE_DIR}/benchmark_yolov5.cc)
if(UNIX AND (NOT APPLE) AND (NOT ANDROID))
  target_link_libraries(benchmark_yolov5 ${FASTDEPLOY_LIBS} gflags pthread)
else()
  target_link_libraries(benchmark_yolov5 ${FASTDEPLOY_LIBS} gflags)
endif()

benchmark/cpp/benchmark_yolov5.cc (new executable file, 110 lines)

@@ -0,0 +1,110 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/benchmark/utils.h"
#include "fastdeploy/vision.h"
#include "flags.h"
bool RunModel(std::string model_file, std::string image_file, size_t warmup,
size_t repeats, size_t dump_period, std::string cpu_mem_file_name,
std::string gpu_mem_file_name) {
// Initialization
auto option = fastdeploy::RuntimeOption();
if (!CreateRuntimeOption(&option)) {
PrintUsage();
return false;
}
if (FLAGS_profile_mode == "runtime") {
option.EnableProfiling(FLAGS_include_h2d_d2h, repeats, warmup);
}
auto model = fastdeploy::vision::detection::YOLOv5(model_file, "", option);
if (!model.Initialized()) {
std::cerr << "Failed to initialize." << std::endl;
return false;
}
auto im = cv::imread(image_file);
// For Runtime
if (FLAGS_profile_mode == "runtime") {
fastdeploy::vision::DetectionResult res;
if (!model.Predict(im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return false;
}
double profile_time = model.GetProfileTime() * 1000;
std::cout << "Runtime(ms): " << profile_time << "ms." << std::endl;
auto vis_im = fastdeploy::vision::VisDetection(im, res);
cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
} else {
// For End2End
// Step1: warm up for warmup times
std::cout << "Warmup " << warmup << " times..." << std::endl;
for (int i = 0; i < warmup; i++) {
fastdeploy::vision::DetectionResult res;
if (!model.Predict(im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return false;
}
}
std::vector<float> end2end_statis;
// Step2: repeat for repeats times
std::cout << "Counting time..." << std::endl;
fastdeploy::TimeCounter tc;
fastdeploy::vision::DetectionResult res;
for (int i = 0; i < repeats; i++) {
if (FLAGS_collect_memory_info && i % dump_period == 0) {
fastdeploy::benchmark::DumpCurrentCpuMemoryUsage(cpu_mem_file_name);
fastdeploy::benchmark::DumpCurrentGpuMemoryUsage(gpu_mem_file_name,
FLAGS_device_id);
}
tc.Start();
if (!model.Predict(im, &res)) {
std::cerr << "Failed to predict." << std::endl;
return false;
}
tc.End();
end2end_statis.push_back(tc.Duration() * 1000);
}
float end2end = std::accumulate(end2end_statis.end() - repeats,
end2end_statis.end(), 0.f) /
repeats;
std::cout << "End2End(ms): " << end2end << "ms." << std::endl;
auto vis_im = fastdeploy::vision::VisDetection(im, res);
cv::imwrite("vis_result.jpg", vis_im);
std::cout << "Visualized result saved in ./vis_result.jpg" << std::endl;
}
return true;
}
int main(int argc, char* argv[]) {
google::ParseCommandLineFlags(&argc, &argv, true);
int repeats = FLAGS_repeat;
int warmup = FLAGS_warmup;
int dump_period = FLAGS_dump_period;
std::string cpu_mem_file_name = "result_cpu.txt";
std::string gpu_mem_file_name = "result_gpu.txt";
// Run model
if (RunModel(FLAGS_model, FLAGS_image, warmup, repeats, dump_period,
cpu_mem_file_name, gpu_mem_file_name) != true) {
exit(1);
}
if (FLAGS_collect_memory_info) {
float cpu_mem = fastdeploy::benchmark::GetCpuMemoryUsage(cpu_mem_file_name);
float gpu_mem = fastdeploy::benchmark::GetGpuMemoryUsage(gpu_mem_file_name);
std::cout << "cpu_rss_mb: " << cpu_mem << "MB." << std::endl;
std::cout << "gpu_rss_mb: " << gpu_mem << "MB." << std::endl;
}
return 0;
}
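The end-to-end branch above reports only the arithmetic mean over repeats runs. The same per-iteration samples can be summarized further if needed; the helper below is an illustrative sketch and not part of the commit, with LatencyStats, Summarize, and the sample values all made up:

#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>

// Illustrative helper (not in the commit): summarize per-iteration latencies
// in milliseconds, e.g. the end2end_statis vector collected above.
struct LatencyStats {
  float mean;
  float p50;
  float p95;
};

LatencyStats Summarize(std::vector<float> samples) {
  LatencyStats stats{};
  if (samples.empty()) return stats;
  stats.mean = std::accumulate(samples.begin(), samples.end(), 0.f) /
               samples.size();
  std::sort(samples.begin(), samples.end());
  stats.p50 = samples[samples.size() / 2];
  stats.p95 = samples[static_cast<size_t>(samples.size() * 0.95)];
  return stats;
}

int main() {
  // Dummy samples in ms, for illustration only.
  std::vector<float> end2end_statis = {10.2f, 9.8f, 10.5f, 11.0f, 9.9f};
  LatencyStats s = Summarize(end2end_statis);
  std::cout << "mean=" << s.mean << " p50=" << s.p50 << " p95=" << s.p95
            << std::endl;
  return 0;
}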

benchmark/cpp/flags.h (new executable file, 99 lines)

@@ -0,0 +1,99 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "gflags/gflags.h"
#include "fastdeploy/utils/perf.h"

DEFINE_string(model, "", "Directory of the inference model.");
DEFINE_string(image, "", "Path of the image file.");
DEFINE_string(device, "cpu",
              "Type of inference device, support 'cpu' or 'gpu'.");
DEFINE_int32(device_id, 0, "device(gpu) id.");
DEFINE_int32(warmup, 200, "Number of warmup for profiling.");
DEFINE_int32(repeat, 1000, "Number of repeats for profiling.");
DEFINE_string(profile_mode, "runtime", "runtime or end2end.");
DEFINE_string(backend, "default",
              "The inference runtime backend, support: ['default', 'ort', "
              "'paddle', 'ov', 'trt', 'paddle_trt']");
DEFINE_int32(cpu_thread_nums, 8, "Set numbers of cpu thread.");
DEFINE_bool(
    include_h2d_d2h, false, "Whether run profiling with h2d and d2h.");
DEFINE_bool(
    use_fp16, false,
    "Whether to use FP16 mode, only support 'trt' and 'paddle_trt' backend");
DEFINE_bool(
    collect_memory_info, false, "Whether to collect memory info");
DEFINE_int32(dump_period, 100, "How often to collect memory info.");

void PrintUsage() {
  std::cout << "Usage: infer_demo --model model_path --image img_path --device "
               "[cpu|gpu] --backend "
               "[default|ort|paddle|ov|trt|paddle_trt] "
               "--use_fp16 false"
            << std::endl;
  std::cout << "Default value of device: cpu" << std::endl;
  std::cout << "Default value of backend: default" << std::endl;
  std::cout << "Default value of use_fp16: false" << std::endl;
}

bool CreateRuntimeOption(fastdeploy::RuntimeOption* option) {
  if (FLAGS_device == "gpu") {
    option->UseGpu();
    if (FLAGS_backend == "ort") {
      option->UseOrtBackend();
    } else if (FLAGS_backend == "paddle") {
      option->UsePaddleInferBackend();
    } else if (FLAGS_backend == "trt" || FLAGS_backend == "paddle_trt") {
      option->UseTrtBackend();
      option->SetTrtInputShape("input", {1, 3, 112, 112});
      if (FLAGS_backend == "paddle_trt") {
        option->EnablePaddleToTrt();
      }
      if (FLAGS_use_fp16) {
        option->EnableTrtFP16();
      }
    } else if (FLAGS_backend == "default") {
      return true;
    } else {
      std::cout << "While inference with GPU, only support "
                   "default/ort/paddle/trt/paddle_trt now, "
                << FLAGS_backend << " is not supported." << std::endl;
      return false;
    }
  } else if (FLAGS_device == "cpu") {
    option->SetCpuThreadNum(FLAGS_cpu_thread_nums);
    if (FLAGS_backend == "ort") {
      option->UseOrtBackend();
    } else if (FLAGS_backend == "ov") {
      option->UseOpenVINOBackend();
    } else if (FLAGS_backend == "paddle") {
      option->UsePaddleInferBackend();
    } else if (FLAGS_backend == "default") {
      return true;
    } else {
      std::cout << "While inference with CPU, only support "
                   "default/ort/ov/paddle now, "
                << FLAGS_backend << " is not supported." << std::endl;
      return false;
    }
  } else {
    std::cerr << "Only support device CPU/GPU now, " << FLAGS_device
              << " is not supported." << std::endl;
    return false;
  }
  return true;
}

(modified Python benchmark script; file name not shown)

@@ -17,7 +17,8 @@ import cv2
import os
import numpy as np
import time
from tqdm import tqdm


def parse_arguments():
    import argparse
@@ -38,19 +39,19 @@ def parse_arguments():
        "--profile_mode",
        type=str,
        default="runtime",
        help="runtime or end2end.")
    parser.add_argument(
        "--repeat",
        required=True,
        type=int,
        default=1000,
        help="number of repeats for profiling.")
    parser.add_argument(
        "--warmup",
        required=True,
        type=int,
        default=50,
        help="number of warmup for profiling.")
    parser.add_argument(
        "--device",
        default="cpu",
@@ -74,7 +75,7 @@ def parse_arguments():
        "--include_h2d_d2h",
        type=ast.literal_eval,
        default=False,
        help="whether run profiling with h2d and d2h")
    args = parser.parse_args()
    return args
@@ -85,7 +86,7 @@ def build_option(args):
    backend = args.backend
    enable_trt_fp16 = args.enable_trt_fp16
    if args.profile_mode == "runtime":
        option.enable_profiling(args.include_h2d_d2h, args.repeat, args.warmup)
    option.set_cpu_thread_num(args.cpu_num_thread)
    if device == "gpu":
        option.use_gpu()
@@ -274,25 +275,27 @@ if __name__ == '__main__':
            enable_gpu = args.device == "gpu"
            monitor = Monitor(enable_gpu, gpu_id)
            monitor.start()

        im_ori = cv2.imread(args.image)
        if args.profile_mode == "runtime":
            result = model.predict(im_ori)
            profile_time = model.get_profile_time()
            dump_result["runtime"] = profile_time * 1000
            f.writelines("Runtime(ms): {} \n".format(
                str(dump_result["runtime"])))
            print("Runtime(ms): {} \n".format(str(dump_result["runtime"])))
        else:
            # end2end
            for i in range(args.warmup):
                result = model.predict(im_ori)

            start = time.time()
            for i in tqdm(range(args.repeat)):
                result = model.predict(im_ori)
            end = time.time()
            dump_result["end2end"] = ((end - start) / args.repeat) * 1000.0
            f.writelines("End2End(ms): {} \n".format(
                str(dump_result["end2end"])))
            print("End2End(ms): {} \n".format(str(dump_result["end2end"])))

        if enable_collect_memory_info:
@@ -304,7 +307,7 @@ if __name__ == '__main__':
                'memory.used'] if 'gpu' in mem_info else 0
            dump_result["gpu_util"] = mem_info['gpu'][
                'utilization.gpu'] if 'gpu' in mem_info else 0

        if enable_collect_memory_info:
            f.writelines("cpu_rss_mb: {} \n".format(
                str(dump_result["cpu_rss_mb"])))

(modified Python benchmark script; file name not shown)

@@ -17,9 +17,9 @@ import cv2
import os
import numpy as np
import time
from tqdm import tqdm


def parse_arguments():
    import argparse
    import ast
@@ -39,19 +39,19 @@ def parse_arguments():
        "--profile_mode",
        type=str,
        default="runtime",
        help="runtime or end2end.")
    parser.add_argument(
        "--repeat",
        required=True,
        type=int,
        default=1000,
        help="number of repeats for profiling.")
    parser.add_argument(
        "--warmup",
        required=True,
        type=int,
        default=50,
        help="number of warmup for profiling.")
    parser.add_argument(
        "--device",
        default="cpu",
@@ -70,7 +70,7 @@ def parse_arguments():
        "--enable_lite_fp16",
        type=ast.literal_eval,
        default=False,
        help="whether enable fp16 in Paddle Lite backend")
    parser.add_argument(
        "--enable_collect_memory_info",
        type=ast.literal_eval,
@@ -80,7 +80,7 @@ def parse_arguments():
        "--include_h2d_d2h",
        type=ast.literal_eval,
        default=False,
        help="whether run profiling with h2d and d2h")
    args = parser.parse_args()
    return args
@@ -92,7 +92,7 @@ def build_option(args):
    enable_trt_fp16 = args.enable_trt_fp16
    enable_lite_fp16 = args.enable_lite_fp16
    if args.profile_mode == "runtime":
        option.enable_profiling(args.include_h2d_d2h, args.repeat, args.warmup)
    option.set_cpu_thread_num(args.cpu_num_thread)
    if device == "gpu":
        option.use_gpu()
@@ -149,7 +149,7 @@ def build_option(args):
        else:
            raise Exception(
                "While inference with CPU, only support default/ort/lite/paddle now, {} is not supported.".
                format(backend))
    elif device == "ascend":
        option.use_ascend()
        if backend == "lite":
@@ -161,11 +161,11 @@ def build_option(args):
        else:
            raise Exception(
                "While inference with CPU, only support default/lite now, {} is not supported.".
                format(backend))
    else:
        raise Exception(
            "Only support device CPU/GPU/Kunlunxin/Ascend now, {} is not supported.".
            format(device))

    return option
@@ -340,19 +340,21 @@ if __name__ == '__main__':
            result = model.predict(im_ori)
            profile_time = model.get_profile_time()
            dump_result["runtime"] = profile_time * 1000
            f.writelines("Runtime(ms): {} \n".format(
                str(dump_result["runtime"])))
            print("Runtime(ms): {} \n".format(str(dump_result["runtime"])))
        else:
            # end2end
            for i in range(args.warmup):
                result = model.predict(im_ori)

            start = time.time()
            for i in tqdm(range(args.repeat)):
                result = model.predict(im_ori)
            end = time.time()
            dump_result["end2end"] = ((end - start) / args.repeat) * 1000.0
            f.writelines("End2End(ms): {} \n".format(
                str(dump_result["end2end"])))
            print("End2End(ms): {} \n".format(str(dump_result["end2end"])))

        if enable_collect_memory_info:
@@ -364,7 +366,7 @@ if __name__ == '__main__':
                'memory.used'] if 'gpu' in mem_info else 0
            dump_result["gpu_util"] = mem_info['gpu'][
                'utilization.gpu'] if 'gpu' in mem_info else 0

        if enable_collect_memory_info:
            f.writelines("cpu_rss_mb: {} \n".format(
                str(dump_result["cpu_rss_mb"])))

fastdeploy/benchmark/benchmark.h (Normal file → Executable file, 10 changed lines)

@@ -18,7 +18,7 @@
#include "fastdeploy/benchmark/option.h"
#include "fastdeploy/benchmark/results.h"

#ifdef ENABLE_BENCHMARK
#define __RUNTIME_PROFILE_LOOP_BEGIN(option, base_loop) \
  int __p_loop = (base_loop); \
  const bool __p_enable_profile = option.enable_profile; \
@@ -75,12 +75,12 @@
        result.time_of_runtime = \
            __p_tc_duration_h / static_cast<double>(__p_repeats_h); \
      } \
  }
#else
#define __RUNTIME_PROFILE_LOOP_BEGIN(option, base_loop) \
  for (int __p_i = 0; __p_i < (base_loop); ++__p_i) {
#define __RUNTIME_PROFILE_LOOP_END(result) }
#define __RUNTIME_PROFILE_LOOP_H2D_D2H_BEGIN(option, base_loop) \
  for (int __p_i_h = 0; __p_i_h < (base_loop); ++__p_i_h) {
#define __RUNTIME_PROFILE_LOOP_H2D_D2H_END(result) }
#endif
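When ENABLE_BENCHMARK is not defined, the #else branch above reduces the loop macros to a bare for loop around whatever statement they wrap, so the profiling hooks compile away. A minimal, self-contained sketch of that disabled path (the macro definitions are copied from the hunk above; DummyOption, DummyResult, and the loop body are made-up placeholders):

#include <iostream>

// Disabled-path definitions, as shown in the #else branch of benchmark.h.
#define __RUNTIME_PROFILE_LOOP_BEGIN(option, base_loop) \
  for (int __p_i = 0; __p_i < (base_loop); ++__p_i) {
#define __RUNTIME_PROFILE_LOOP_END(result) }

struct DummyOption {};  // placeholder for BenchmarkOption
struct DummyResult {};  // placeholder for BenchmarkResult

int main() {
  DummyOption option;
  DummyResult result;
  int calls = 0;
  // The wrapped statement runs base_loop (= 1) times. With ENABLE_BENCHMARK
  // defined, the real macros would instead repeat it and record timings into
  // the result object, as the first hunk suggests.
  __RUNTIME_PROFILE_LOOP_BEGIN(option, 1)
  ++calls;  // stands in for a backend inference call
  __RUNTIME_PROFILE_LOOP_END(result)
  std::cout << "calls = " << calls << std::endl;  // prints 1
  return 0;
}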

fastdeploy/benchmark/option.h (Normal file → Executable file, 26 changed lines)

@@ -26,22 +26,22 @@ struct BenchmarkOption {
  int warmup = 50;              ///< Warmup for backend inference.
  int repeats = 100;            ///< Repeats for backend inference.
  bool enable_profile = false;  ///< Whether to use profile or not.
  bool include_h2d_d2h = false; ///< Whether to include time of H2D_D2H for time of runtime. // NOLINT

  friend std::ostream& operator<<(
      std::ostream& output, const BenchmarkOption &option) {
    if (!option.include_h2d_d2h) {
      output << "Running profiling for Runtime "
             << "without H2D and D2H, ";
    } else {
      output << "Running profiling for Runtime "
             << "with H2D and D2H, ";
    }
    output << "Repeats: " << option.repeats << ", "
           << "Warmup: " << option.warmup;
    return output;
  }
};

}  // namespace benchmark
}  // namespace fastdeploy
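Since BenchmarkOption provides operator<<, the configured profiling settings can be logged directly. A small usage sketch, assuming the FastDeploy headers and library are available; the field values are arbitrary examples:

#include <iostream>
#include "fastdeploy/benchmark/option.h"

int main() {
  fastdeploy::benchmark::BenchmarkOption option;
  option.enable_profile = true;
  option.include_h2d_d2h = false;
  option.warmup = 200;
  option.repeats = 1000;
  // Uses the operator<< shown in the hunk above; expected to print something
  // like: "Running profiling for Runtime without H2D and D2H, Repeats: 1000, Warmup: 200"
  std::cout << option << std::endl;
  return 0;
}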

fastdeploy/benchmark/utils.cc (new executable file, 93 lines)

@@ -0,0 +1,93 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "fastdeploy/benchmark/utils.h"
namespace fastdeploy {
namespace benchmark {
void DumpCurrentCpuMemoryUsage(const std::string& name) {
int iPid = static_cast<int>(getpid());
std::string command = "pmap -x " + std::to_string(iPid) + " | grep total";
FILE* pp = popen(command.data(), "r");
if (!pp) return;
char tmp[1024];
while (fgets(tmp, sizeof(tmp), pp) != NULL) {
std::ofstream write;
write.open(name, std::ios::app);
write << tmp;
write.close();
}
pclose(pp);
return;
}
void DumpCurrentGpuMemoryUsage(const std::string& name, int device_id) {
std::string command = "nvidia-smi --id=" + std::to_string(device_id) +
" --query-gpu=index,uuid,name,timestamp,memory.total,"
"memory.free,memory.used,utilization.gpu,utilization."
"memory --format=csv,noheader,nounits";
FILE* pp = popen(command.data(), "r");
if (!pp) return;
char tmp[1024];
while (fgets(tmp, sizeof(tmp), pp) != NULL) {
std::ofstream write;
write.open(name, std::ios::app);
write << tmp;
write.close();
}
pclose(pp);
return;
}
float GetCpuMemoryUsage(const std::string& name) {
std::ifstream read(name);
std::string line;
float max_cpu_mem = -1;
while (getline(read, line)) {
std::stringstream ss(line);
std::string tmp;
std::vector<std::string> nums;
while (getline(ss, tmp, ' ')) {
tmp = strip(tmp);
if (tmp.empty()) continue;
nums.push_back(tmp);
}
max_cpu_mem = std::max(max_cpu_mem, stof(nums[3]));
}
return max_cpu_mem / 1024;
}
float GetGpuMemoryUsage(const std::string& name) {
std::ifstream read(name);
std::string line;
float max_gpu_mem = -1;
while (getline(read, line)) {
std::stringstream ss(line);
std::string tmp;
std::vector<std::string> nums;
while (getline(ss, tmp, ',')) {
tmp = strip(tmp);
if (tmp.empty()) continue;
nums.push_back(tmp);
}
max_gpu_mem = std::max(max_gpu_mem, stof(nums[6]));
}
return max_gpu_mem;
}
} // namespace benchmark
} // namespace fastdeploy
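Both parsers rely on fixed field positions: GetCpuMemoryUsage takes the fourth whitespace-separated field of the "pmap -x <pid> | grep total" footer, which on typical procps output is the RSS column in KB (hence the division by 1024), while GetGpuMemoryUsage takes the seventh comma-separated field, which is memory.used in MB given the query string in the command above. A standalone sketch of the CPU-side parsing on a made-up sample line, for illustration only and not part of the commit:

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
  // Sample footer line, not real pmap output; the real line is written into
  // the result_cpu.txt log by DumpCurrentCpuMemoryUsage above.
  std::string line = "total kB          171446   123456     7890";
  std::stringstream ss(line);
  std::string field;
  std::vector<std::string> fields;
  // Split on spaces and drop empty tokens, mirroring the loop in
  // GetCpuMemoryUsage (which additionally calls strip()).
  while (std::getline(ss, field, ' ')) {
    if (!field.empty()) fields.push_back(field);
  }
  // Field index 3 is assumed to be RSS in KB; convert to MB.
  std::cout << "cpu_rss_mb: " << std::stof(fields[3]) / 1024 << std::endl;
  return 0;
}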

fastdeploy/benchmark/utils.h (new executable file, 53 lines)

@@ -0,0 +1,53 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <sys/types.h>
#include <unistd.h>
#include <cmath>

#include "fastdeploy/utils/utils.h"

namespace fastdeploy {
namespace benchmark {

// Remove the ch characters at both ends of str
std::string strip(const std::string& str, char ch = ' ') {
  int i = 0;
  while (str[i] == ch) {
    i++;
  }
  int j = str.size() - 1;
  while (str[j] == ch) {
    j--;
  }
  return str.substr(i, j + 1 - i);
}

// Record current cpu memory usage into file
FASTDEPLOY_DECL void DumpCurrentCpuMemoryUsage(const std::string& name);

// Record current gpu memory usage into file
FASTDEPLOY_DECL void DumpCurrentGpuMemoryUsage(const std::string& name,
                                               int device_id);

// Get Max cpu memory usage
FASTDEPLOY_DECL float GetCpuMemoryUsage(const std::string& name);

// Get Max gpu memory usage
FASTDEPLOY_DECL float GetGpuMemoryUsage(const std::string& name);

}  // namespace benchmark
}  // namespace fastdeploy
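For reference, a hedged usage sketch showing how these helpers pair up, mirroring the way benchmark_yolov5.cc drives them; it assumes the FastDeploy SDK is on the include and link path, and the loop body and file names are illustrative:

#include <iostream>
#include <string>
#include "fastdeploy/benchmark/utils.h"

int main() {
  const std::string cpu_log = "result_cpu.txt";  // same names the demo uses
  const std::string gpu_log = "result_gpu.txt";
  // Sample memory periodically while the workload runs...
  for (int i = 0; i < 10; ++i) {
    fastdeploy::benchmark::DumpCurrentCpuMemoryUsage(cpu_log);
    fastdeploy::benchmark::DumpCurrentGpuMemoryUsage(gpu_log, /*device_id=*/0);
    // ... run one inference step here ...
  }
  // ...then read back the peak values recorded in the log files.
  std::cout << "cpu_rss_mb: "
            << fastdeploy::benchmark::GetCpuMemoryUsage(cpu_log) << std::endl;
  std::cout << "gpu_rss_mb: "
            << fastdeploy::benchmark::GetGpuMemoryUsage(gpu_log) << std::endl;
  return 0;
}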

(modified Paddle Lite backend option header; file name not shown)

@@ -81,4 +81,5 @@ struct LiteBackendOption {
      nnadapter_dynamic_shape_info = {{"", {{0}}}};
  std::vector<std::string> nnadapter_device_names = {};
};

}  // namespace fastdeploy

fastdeploy/runtime/runtime_option.h (Normal file → Executable file, mode change only)

fastdeploy/utils/utils.h (Normal file → Executable file, mode change only)