diff --git a/CMakeLists.txt b/CMakeLists.txt index 549d7b708..44ff6c786 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -240,12 +240,12 @@ if(ENABLE_PADDLE_BACKEND) add_definitions(-DENABLE_PADDLE_BACKEND) list(APPEND ALL_DEPLOY_SRCS ${DEPLOY_PADDLE_SRCS}) include(${PROJECT_SOURCE_DIR}/cmake/paddle_inference.cmake) - if(NOT APPLE) - list(APPEND DEPEND_LIBS external_paddle_inference external_dnnl external_omp) - else() - # no third parties libs(mkldnn and omp) need to - # link into paddle_inference on MacOS OSX. - list(APPEND DEPEND_LIBS external_paddle_inference) + list(APPEND DEPEND_LIBS external_paddle_inference) + if(external_dnnl_FOUND) + list(APPEND DEPEND_LIBS external_dnnl external_omp) + endif() + if(external_ort_FOUND) + list(APPEND DEPEND_LIBS external_p2o external_ort) endif() endif() @@ -387,9 +387,9 @@ if(ENABLE_TRT_BACKEND) find_package(Python COMPONENTS Interpreter Development REQUIRED) message(STATUS "Copying ${TRT_DIRECTORY}/lib to ${CMAKE_CURRENT_BINARY_DIR}/third_libs/install/tensorrt/lib ...") execute_process(COMMAND ${Python_EXECUTABLE} ${PROJECT_SOURCE_DIR}/scripts/copy_directory.py ${TRT_DIRECTORY}/lib ${CMAKE_CURRENT_BINARY_DIR}/third_libs/install/tensorrt/lib) - file(GLOB_RECURSE TRT_STAIC_LIBS ${CMAKE_CURRENT_BINARY_DIR}/third_libs/install/tensorrt/lib/*.a) + file(GLOB_RECURSE TRT_STATIC_LIBS ${CMAKE_CURRENT_BINARY_DIR}/third_libs/install/tensorrt/lib/*.a) if(TRT_STATIC_LIBS) - file(REMOVE ${TRT_STAIC_LIBS}) + file(REMOVE ${TRT_STATIC_LIBS}) endif() if(UNIX AND (NOT APPLE) AND (NOT ANDROID)) execute_process(COMMAND sh -c "ls *.so*" WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/third_libs/install/tensorrt/lib diff --git a/FastDeploy.cmake.in b/FastDeploy.cmake.in index 456a4d321..83114e901 100755 --- a/FastDeploy.cmake.in +++ b/FastDeploy.cmake.in @@ -74,10 +74,9 @@ if(ENABLE_PADDLE_BACKEND) set(DNNL_LIB "${CMAKE_CURRENT_LIST_DIR}/third_libs/install/paddle_inference/third_party/install/mkldnn/lib/libmkldnn.so.0") set(IOMP_LIB "${CMAKE_CURRENT_LIST_DIR}/third_libs/install/paddle_inference/third_party/install/mklml/lib/libiomp5.so") endif() - if(NOT APPLE) - list(APPEND FASTDEPLOY_LIBS ${PADDLE_LIB} ${DNNL_LIB} ${IOMP_LIB}) - else() - list(APPEND FASTDEPLOY_LIBS ${PADDLE_LIB}) + list(APPEND FASTDEPLOY_LIBS ${PADDLE_LIB}) + if(EXISTS "${DNNL_LIB}") + list(APPEND FASTDEPLOY_LIBS ${DNNL_LIB} ${IOMP_LIB}) endif() endif() diff --git a/benchmark/benchmark_uie.py b/benchmark/benchmark_uie.py new file mode 100644 index 000000000..44c562d7e --- /dev/null +++ b/benchmark/benchmark_uie.py @@ -0,0 +1,321 @@ +import numpy as np +import os +import time +import distutils.util +import sys +import json + +import fastdeploy as fd +from fastdeploy.text import UIEModel, SchemaLanguage + + +def parse_arguments(): + import argparse + import ast + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_dir", + required=True, + help="The directory of model and tokenizer.") + parser.add_argument( + "--data_path", required=True, help="The path of uie data.") + parser.add_argument( + "--device", + type=str, + default='cpu', + choices=['gpu', 'cpu'], + help="Type of inference device, support 'cpu' or 'gpu'.") + parser.add_argument( + "--backend", + type=str, + default='paddle', + choices=['ort', 'paddle', 'trt', 'paddle_trt', 'ov'], + help="The inference runtime backend.") + parser.add_argument( + "--device_id", type=int, default=0, help="device(gpu) id") + parser.add_argument( + "--batch_size", type=int, default=1, help="The batch size of data.") + parser.add_argument( + "--max_length", + type=int, + default=128, + help="The max length of sequence.") + parser.add_argument( + "--cpu_num_threads", + type=int, + default=8, + help="The number of threads when inferring on cpu.") + parser.add_argument( + "--enable_trt_fp16", + type=distutils.util.strtobool, + default=False, + help="whether enable fp16 in trt backend") + parser.add_argument( + "--epoch", type=int, default=1, help="The epoch of test") + parser.add_argument( + "--enable_collect_memory_info", + type=ast.literal_eval, + default=False, + help="whether enable collect memory info") + return parser.parse_args() + + +def build_option(args): + option = fd.RuntimeOption() + if args.device == 'cpu': + option.use_cpu() + option.set_cpu_thread_num(args.cpu_num_threads) + else: + option.use_gpu(args.device_id) + if args.backend == 'paddle': + option.use_paddle_backend() + elif args.backend == 'ort': + option.use_ort_backend() + elif args.backend == 'ov': + option.use_openvino_backend() + else: + option.use_trt_backend() + if args.backend == 'paddle_trt': + option.enable_paddle_to_trt() + option.enable_paddle_trt_collect_shape() + trt_file = os.path.join(args.model_dir, "infer.trt") + option.set_trt_input_shape( + 'input_ids', + min_shape=[1, 1], + opt_shape=[args.batch_size, args.max_length // 2], + max_shape=[args.batch_size, args.max_length]) + option.set_trt_input_shape( + 'token_type_ids', + min_shape=[1, 1], + opt_shape=[args.batch_size, args.max_length // 2], + max_shape=[args.batch_size, args.max_length]) + option.set_trt_input_shape( + 'pos_ids', + min_shape=[1, 1], + opt_shape=[args.batch_size, args.max_length // 2], + max_shape=[args.batch_size, args.max_length]) + option.set_trt_input_shape( + 'att_mask', + min_shape=[1, 1], + opt_shape=[args.batch_size, args.max_length // 2], + max_shape=[args.batch_size, args.max_length]) + if args.enable_trt_fp16: + option.enable_trt_fp16() + trt_file = trt_file + ".fp16" + option.set_trt_cache_file(trt_file) + return option + + +class StatBase(object): + """StatBase""" + nvidia_smi_path = "nvidia-smi" + gpu_keys = ('index', 'uuid', 'name', 'timestamp', 'memory.total', + 'memory.free', 'memory.used', 'utilization.gpu', + 'utilization.memory') + nu_opt = ',nounits' + cpu_keys = ('cpu.util', 'memory.util', 'memory.used') + + +class Monitor(StatBase): + """Monitor""" + + def __init__(self, use_gpu=False, gpu_id=0, interval=0.1): + self.result = {} + self.gpu_id = gpu_id + self.use_gpu = use_gpu + self.interval = interval + self.cpu_stat_q = multiprocessing.Queue() + + def start(self): + cmd = '%s --id=%s --query-gpu=%s --format=csv,noheader%s -lms 50' % ( + StatBase.nvidia_smi_path, self.gpu_id, ','.join(StatBase.gpu_keys), + StatBase.nu_opt) + if self.use_gpu: + self.gpu_stat_worker = subprocess.Popen( + cmd, + stderr=subprocess.STDOUT, + stdout=subprocess.PIPE, + shell=True, + close_fds=True, + preexec_fn=os.setsid) + # cpu stat + pid = os.getpid() + self.cpu_stat_worker = multiprocessing.Process( + target=self.cpu_stat_func, + args=(self.cpu_stat_q, pid, self.interval)) + self.cpu_stat_worker.start() + + def stop(self): + try: + if self.use_gpu: + os.killpg(self.gpu_stat_worker.pid, signal.SIGUSR1) + # os.killpg(p.pid, signal.SIGTERM) + self.cpu_stat_worker.terminate() + self.cpu_stat_worker.join(timeout=0.01) + except Exception as e: + print(e) + return + + # gpu + if self.use_gpu: + lines = self.gpu_stat_worker.stdout.readlines() + lines = [ + line.strip().decode("utf-8") for line in lines + if line.strip() != '' + ] + gpu_info_list = [{ + k: v + for k, v in zip(StatBase.gpu_keys, line.split(', ')) + } for line in lines] + if len(gpu_info_list) == 0: + return + result = gpu_info_list[0] + for item in gpu_info_list: + for k in item.keys(): + if k not in ["name", "uuid", "timestamp"]: + result[k] = max(int(result[k]), int(item[k])) + else: + result[k] = max(result[k], item[k]) + self.result['gpu'] = result + + # cpu + cpu_result = {} + if self.cpu_stat_q.qsize() > 0: + cpu_result = { + k: v + for k, v in zip(StatBase.cpu_keys, self.cpu_stat_q.get()) + } + while not self.cpu_stat_q.empty(): + item = { + k: v + for k, v in zip(StatBase.cpu_keys, self.cpu_stat_q.get()) + } + for k in StatBase.cpu_keys: + cpu_result[k] = max(cpu_result[k], item[k]) + cpu_result['name'] = cpuinfo.get_cpu_info()['brand_raw'] + self.result['cpu'] = cpu_result + + def output(self): + return self.result + + def cpu_stat_func(self, q, pid, interval=0.0): + """cpu stat function""" + stat_info = psutil.Process(pid) + while True: + # pid = os.getpid() + cpu_util, mem_util, mem_use = stat_info.cpu_percent( + ), stat_info.memory_percent(), round(stat_info.memory_info().rss / + 1024.0 / 1024.0, 4) + q.put([cpu_util, mem_util, mem_use]) + time.sleep(interval) + return + + +def get_dataset(data_path, max_seq_len=512): + json_lines = [] + with open(data_path, 'r', encoding='utf-8') as f: + for line in f: + json_line = json.loads(line) + content = json_line['content'].strip() + prompt = json_line['prompt'] + # Model Input is aslike: [CLS] Prompt [SEP] Content [SEP] + # It include three summary tokens. + if max_seq_len <= len(prompt) + 3: + raise ValueError( + "The value of max_seq_len is too small, please set a larger value" + ) + json_lines.append(json_line) + + return json_lines + + +if __name__ == '__main__': + args = parse_arguments() + runtime_option = build_option(args) + model_path = os.path.join(args.model_dir, "inference.pdmodel") + param_path = os.path.join(args.model_dir, "inference.pdiparams") + vocab_path = os.path.join(args.model_dir, "vocab.txt") + + gpu_id = args.device_id + enable_collect_memory_info = args.enable_collect_memory_info + dump_result = dict() + end2end_statis = list() + cpu_mem = list() + gpu_mem = list() + gpu_util = list() + if args.device == "cpu": + file_path = args.model_dir + "_model_" + args.backend + "_" + \ + args.device + "_" + str(args.cpu_num_threads) + ".txt" + else: + if args.enable_trt_fp16: + file_path = args.model_dir + "_model_" + \ + args.backend + "_fp16_" + args.device + ".txt" + else: + file_path = args.model_dir + "_model_" + args.backend + "_" + args.device + ".txt" + f = open(file_path, "w") + f.writelines("===={}====: \n".format(os.path.split(file_path)[-1][:-4])) + + ds = get_dataset(args.data_path) + schema = ["时间"] + uie = UIEModel( + model_path, + param_path, + vocab_path, + position_prob=0.5, + max_length=args.max_length, + batch_size=args.batch_size, + schema=schema, + runtime_option=runtime_option, + schema_language=SchemaLanguage.ZH) + + try: + if enable_collect_memory_info: + import multiprocessing + import subprocess + import psutil + import signal + import cpuinfo + enable_gpu = args.device == "gpu" + monitor = Monitor(enable_gpu, gpu_id) + monitor.start() + uie.enable_record_time_of_runtime() + + for ep in range(args.epoch): + for i, sample in enumerate(ds): + curr_start = time.time() + uie.set_schema([sample['prompt']]) + result = uie.predict([sample['content']]) + end2end_statis.append(time.time() - curr_start) + runtime_statis = uie.print_statis_info_of_runtime() + + warmup_iter = args.epoch * len(ds) // 5 + + end2end_statis_repeat = end2end_statis[warmup_iter:] + if enable_collect_memory_info: + monitor.stop() + mem_info = monitor.output() + dump_result["cpu_rss_mb"] = mem_info['cpu'][ + 'memory.used'] if 'cpu' in mem_info else 0 + dump_result["gpu_rss_mb"] = mem_info['gpu'][ + 'memory.used'] if 'gpu' in mem_info else 0 + dump_result["gpu_util"] = mem_info['gpu'][ + 'utilization.gpu'] if 'gpu' in mem_info else 0 + + dump_result["runtime"] = runtime_statis["avg_time"] * 1000 + dump_result["end2end"] = np.mean(end2end_statis_repeat) * 1000 + + time_cost_str = f"Runtime(ms): {dump_result['runtime']}\n" \ + f"End2End(ms): {dump_result['end2end']}\n" + f.writelines(time_cost_str) + print(time_cost_str) + + if enable_collect_memory_info: + mem_info_str = f"cpu_rss_mb: {dump_result['cpu_rss_mb']}\n" \ + f"gpu_rss_mb: {dump_result['gpu_rss_mb']}\n" \ + f"gpu_util: {dump_result['gpu_util']}\n" + f.writelines(mem_info_str) + print(mem_info_str) + except: + f.writelines("!!!!!Infer Failed\n") + + f.close() diff --git a/benchmark/run_benchmark_uie.sh b/benchmark/run_benchmark_uie.sh new file mode 100644 index 000000000..51eb5d973 --- /dev/null +++ b/benchmark/run_benchmark_uie.sh @@ -0,0 +1,27 @@ +# wget https://bj.bcebos.com/fastdeploy/benchmark/uie/reimbursement_form_data.txt +# wget https://bj.bcebos.com/fastdeploy/models/uie/uie-base.tgz +# tar xvfz uie-base.tgz + +DEVICE_ID=0 + +echo "[FastDeploy] Running UIE benchmark..." + +# GPU +echo "-------------------------------GPU Benchmark---------------------------------------" +python benchmark_uie.py --epoch 5 --model_dir uie-base --data_path reimbursement_form_data.txt --backend paddle --device_id $DEVICE_ID --device gpu --enable_collect_memory_info True +python benchmark_uie.py --epoch 5 --model_dir uie-base --data_path reimbursement_form_data.txt --backend ort --device_id $DEVICE_ID --device gpu --enable_collect_memory_info True +python benchmark_uie.py --epoch 5 --model_dir uie-base --data_path reimbursement_form_data.txt --backend paddle_trt --device_id $DEVICE_ID --device gpu --enable_trt_fp16 False --enable_collect_memory_info True +python benchmark_uie.py --epoch 5 --model_dir uie-base --data_path reimbursement_form_data.txt --backend trt --device_id $DEVICE_ID --device gpu --enable_trt_fp16 False --enable_collect_memory_info True +python benchmark_uie.py --epoch 5 --model_dir uie-base --data_path reimbursement_form_data.txt --backend paddle_trt --device_id $DEVICE_ID --device gpu --enable_trt_fp16 True --enable_collect_memory_info True +python benchmark_uie.py --epoch 5 --model_dir uie-base --data_path reimbursement_form_data.txt --backend trt --device_id $DEVICE_ID --device gpu --enable_trt_fp16 True --enable_collect_memory_info True +echo "-----------------------------------------------------------------------------------" + +# CPU +echo "-------------------------------CPU Benchmark---------------------------------------" +for cpu_num_threads in 1 8; +do + python benchmark_uie.py --epoch 5 --model_dir uie-base --data_path reimbursement_form_data.txt --backend paddle --device cpu --cpu_num_threads ${cpu_num_threads} --enable_collect_memory_info True + python benchmark_uie.py --epoch 5 --model_dir uie-base --data_path reimbursement_form_data.txt --backend ort --device cpu --cpu_num_threads ${cpu_num_threads} --enable_collect_memory_info True + python benchmark_uie.py --epoch 5 --model_dir uie-base --data_path reimbursement_form_data.txt --backend ov --device cpu --cpu_num_threads ${cpu_num_threads} --enable_collect_memory_info True +done +echo "-----------------------------------------------------------------------------------" diff --git a/cmake/paddle_inference.cmake b/cmake/paddle_inference.cmake index 3822f9ac3..e33a14eb6 100644 --- a/cmake/paddle_inference.cmake +++ b/cmake/paddle_inference.cmake @@ -40,16 +40,24 @@ if(WIN32) CACHE FILEPATH "paddle_inference compile library." FORCE) set(DNNL_LIB "${PADDLEINFERENCE_INSTALL_DIR}/third_party/install/mkldnn/lib/mkldnn.lib") set(OMP_LIB "${PADDLEINFERENCE_INSTALL_DIR}/third_party/install/mklml/lib/libiomp5md.lib") + set(P2O_LIB "${PADDLEINFERENCE_INSTALL_DIR}/third_party/install/paddle2onnx/lib/paddle2onnx.lib") + set(ORT_LIB "${PADDLEINFERENCE_INSTALL_DIR}/third_party/install/onnxruntime/lib/onnxruntime.lib") elseif(APPLE) set(PADDLEINFERENCE_COMPILE_LIB "${PADDLEINFERENCE_INSTALL_DIR}/paddle/lib/libpaddle_inference.dylib" CACHE FILEPATH "paddle_inference compile library." FORCE) + set(DNNL_LIB "${PADDLEINFERENCE_INSTALL_DIR}/third_party/install/mkldnn/lib/libdnnl.so.2") + set(OMP_LIB "${PADDLEINFERENCE_INSTALL_DIR}/third_party/install/mklml/lib/libiomp5.so") + set(P2O_LIB "${PADDLEINFERENCE_INSTALL_DIR}/third_party/install/paddle2onnx/lib/libpaddle2onnx.dylib") + set(ORT_LIB "${PADDLEINFERENCE_INSTALL_DIR}/third_party/install/onnxruntime/lib/libonnxruntime.dylib") else() set(PADDLEINFERENCE_COMPILE_LIB "${PADDLEINFERENCE_INSTALL_DIR}/paddle/lib/libpaddle_inference.so" CACHE FILEPATH "paddle_inference compile library." FORCE) set(DNNL_LIB "${PADDLEINFERENCE_INSTALL_DIR}/third_party/install/mkldnn/lib/libdnnl.so.2") set(OMP_LIB "${PADDLEINFERENCE_INSTALL_DIR}/third_party/install/mklml/lib/libiomp5.so") + set(P2O_LIB "${PADDLEINFERENCE_INSTALL_DIR}/third_party/install/paddle2onnx/lib/libpaddle2onnx.so") + set(ORT_LIB "${PADDLEINFERENCE_INSTALL_DIR}/third_party/install/onnxruntime/lib/libonnxruntime.so") endif(WIN32) @@ -116,16 +124,23 @@ set_property(TARGET external_paddle_inference PROPERTY IMPORTED_LOCATION ${PADDLEINFERENCE_COMPILE_LIB}) add_dependencies(external_paddle_inference ${PADDLEINFERENCE_PROJECT}) -if (NOT APPLE) - # no third parties libs(mkldnn and omp) need to - # link into paddle_inference on MacOS OSX. - add_library(external_dnnl STATIC IMPORTED GLOBAL) - set_property(TARGET external_dnnl PROPERTY IMPORTED_LOCATION - ${DNNL_LIB}) - add_dependencies(external_dnnl ${PADDLEINFERENCE_PROJECT}) - add_library(external_omp STATIC IMPORTED GLOBAL) - set_property(TARGET external_omp PROPERTY IMPORTED_LOCATION - ${OMP_LIB}) - add_dependencies(external_omp ${PADDLEINFERENCE_PROJECT}) -endif() +add_library(external_p2o STATIC IMPORTED GLOBAL) +set_property(TARGET external_p2o PROPERTY IMPORTED_LOCATION + ${P2O_LIB}) +add_dependencies(external_p2o ${PADDLEINFERENCE_PROJECT}) + +add_library(external_ort STATIC IMPORTED GLOBAL) +set_property(TARGET external_ort PROPERTY IMPORTED_LOCATION + ${ORT_LIB}) +add_dependencies(external_ort ${PADDLEINFERENCE_PROJECT}) + +add_library(external_dnnl STATIC IMPORTED GLOBAL) +set_property(TARGET external_dnnl PROPERTY IMPORTED_LOCATION + ${DNNL_LIB}) +add_dependencies(external_dnnl ${PADDLEINFERENCE_PROJECT}) + +add_library(external_omp STATIC IMPORTED GLOBAL) +set_property(TARGET external_omp PROPERTY IMPORTED_LOCATION + ${OMP_LIB}) +add_dependencies(external_omp ${PADDLEINFERENCE_PROJECT}) diff --git a/cmake/toolchain.cmake b/cmake/toolchain.cmake index 4b3485748..85bd05798 100755 --- a/cmake/toolchain.cmake +++ b/cmake/toolchain.cmake @@ -10,7 +10,7 @@ if (DEFINED TARGET_ABI) set(OPENCV_URL "https://bj.bcebos.com/fastdeploy/third_libs/opencv-linux-armv7hf-4.6.0.tgz") set(OPENCV_FILENAME "opencv-linux-armv7hf-4.6.0") if(WITH_TIMVX) - set(PADDLELITE_URL "https://bj.bcebos.com/fastdeploy/third_libs/lite-linux-armhf-timvx-1130.tgz") + set(PADDLELITE_URL "https://bj.bcebos.com/fastdeploy/third_libs/lite-linux-armhf-timvx-20221229.tgz") else() message(STATUS "PADDLELITE_URL will be configured if WITH_TIMVX=ON.") endif() diff --git a/examples/text/uie/serving/models/uie/1/model.py b/examples/text/uie/serving/models/uie/1/model.py index 5bb1c8164..b839ae065 100644 --- a/examples/text/uie/serving/models/uie/1/model.py +++ b/examples/text/uie/serving/models/uie/1/model.py @@ -141,7 +141,7 @@ class TritonPythonModel: self.uie_model_.set_schema(schema) results = self.uie_model_.predict(texts, return_dict=True) - results = np.array(results, dtype=np.object) + results = np.array(results, dtype=np.object_) out_tensor = pb_utils.Tensor(self.output_names[0], results) inference_response = pb_utils.InferenceResponse( output_tensors=[out_tensor, ]) diff --git a/examples/vision/classification/paddleclas/rknpu2/cpp/README.md b/examples/vision/classification/paddleclas/rknpu2/cpp/README.md index 1e1883486..c21d1d77b 100644 --- a/examples/vision/classification/paddleclas/rknpu2/cpp/README.md +++ b/examples/vision/classification/paddleclas/rknpu2/cpp/README.md @@ -64,8 +64,8 @@ cd ./build/install ## 运行结果展示 ClassifyResult( -label_ids: 153, -scores: 0.684570, +label_ids: 153, +scores: 0.684570, ) ## 注意事项 @@ -75,4 +75,4 @@ DisablePermute(C++)或`disable_permute(Python),在预处理阶段禁用数据 ## 其它文档 - [ResNet50_vd Python 部署](../python) - [模型预测结果说明](../../../../../../docs/api/vision_results/) -- [转换ResNet50_vd RKNN模型文档](../README.md) \ No newline at end of file +- [转换ResNet50_vd RKNN模型文档](../README.md) diff --git a/examples/vision/classification/paddleclas/rknpu2/python/README.md b/examples/vision/classification/paddleclas/rknpu2/python/README.md index b85bb81f7..f1f0994d8 100644 --- a/examples/vision/classification/paddleclas/rknpu2/python/README.md +++ b/examples/vision/classification/paddleclas/rknpu2/python/README.md @@ -19,8 +19,8 @@ python3 infer.py --model_file ./ResNet50_vd_infer/ResNet50_vd_infer_rk3588.rknn # 运行完成后返回结果如下所示 ClassifyResult( -label_ids: 153, -scores: 0.684570, +label_ids: 153, +scores: 0.684570, ) ``` @@ -32,4 +32,4 @@ DisablePermute(C++)或`disable_permute(Python),在预处理阶段禁用数据 ## 其它文档 - [ResNet50_vd C++部署](../cpp) - [模型预测结果说明](../../../../../../docs/api/vision_results/) -- [转换ResNet50_vd RKNN模型文档](../README.md) \ No newline at end of file +- [转换ResNet50_vd RKNN模型文档](../README.md) diff --git a/examples/vision/classification/paddleclas/serving/models/postprocess/1/model.py b/examples/vision/classification/paddleclas/serving/models/postprocess/1/model.py index 0ab7dcdc4..de000f6ee 100755 --- a/examples/vision/classification/paddleclas/serving/models/postprocess/1/model.py +++ b/examples/vision/classification/paddleclas/serving/models/postprocess/1/model.py @@ -92,7 +92,7 @@ class TritonPythonModel: results = self.postprocess_.run([infer_outputs, ]) r_str = fd.vision.utils.fd_result_to_json(results) - r_np = np.array(r_str, dtype=np.object) + r_np = np.array(r_str, dtype=np.object_) out_tensor = pb_utils.Tensor(self.output_names[0], r_np) inference_response = pb_utils.InferenceResponse( output_tensors=[out_tensor, ]) diff --git a/examples/vision/detection/nanodet_plus/python/README.md b/examples/vision/detection/nanodet_plus/python/README.md index b5085662c..a89e15d1b 100644 --- a/examples/vision/detection/nanodet_plus/python/README.md +++ b/examples/vision/detection/nanodet_plus/python/README.md @@ -69,7 +69,7 @@ NanoDetPlus模型加载和初始化,其中model_file为导出的ONNX模型格 > > * **padding_value**(list[float]): 通过此参数可以修改图片在resize时候做填充(padding)的值, 包含三个浮点型元素, 分别表示三个通道的值, 默认值为[0, 0, 0] > > * **keep_ratio**(bool): 通过此参数指定resize时是否保持宽高比例不变,默认是fasle. > > * **reg_max**(int): GFL回归中的reg_max参数,默认是7. -> > * **downsample_strides**(list[int]): 通过此参数可以修改生成anchor的特征图的下采样倍数, 包含三个整型元素, 分别表示默认的生成anchor的下采样倍数, 默认值为[8, 16, 32, 64] +> > * **downsample_strides**(list[int]): 通过此参数可以修改生成anchor的特征图的下采样倍数, 包含四个整型元素, 分别表示默认的生成anchor的下采样倍数, 默认值为[8, 16, 32, 64] diff --git a/examples/vision/detection/paddledetection/serving/models/postprocess/1/model.py b/examples/vision/detection/paddledetection/serving/models/postprocess/1/model.py index 4872b0dee..35054e516 100644 --- a/examples/vision/detection/paddledetection/serving/models/postprocess/1/model.py +++ b/examples/vision/detection/paddledetection/serving/models/postprocess/1/model.py @@ -95,7 +95,7 @@ class TritonPythonModel: results = self.postprocess_.run(infer_outputs) r_str = fd.vision.utils.fd_result_to_json(results) - r_np = np.array(r_str, dtype=np.object) + r_np = np.array(r_str, dtype=np.object_) out_tensor = pb_utils.Tensor(self.output_names[0], r_np) inference_response = pb_utils.InferenceResponse( output_tensors=[out_tensor, ]) diff --git a/examples/vision/detection/yolov5/serving/models/postprocess/1/model.py b/examples/vision/detection/yolov5/serving/models/postprocess/1/model.py index 7c608db43..1204446c4 100644 --- a/examples/vision/detection/yolov5/serving/models/postprocess/1/model.py +++ b/examples/vision/detection/yolov5/serving/models/postprocess/1/model.py @@ -96,7 +96,7 @@ class TritonPythonModel: results = self.postprocessor_.run([infer_outputs], im_infos) r_str = fd.vision.utils.fd_result_to_json(results) - r_np = np.array(r_str, dtype=np.object) + r_np = np.array(r_str, dtype=np.object_) out_tensor = pb_utils.Tensor(self.output_names[0], r_np) inference_response = pb_utils.InferenceResponse( diff --git a/examples/vision/detection/yolov5/serving/models/preprocess/1/model.py b/examples/vision/detection/yolov5/serving/models/preprocess/1/model.py index cf4f7e8e8..d60de6541 100644 --- a/examples/vision/detection/yolov5/serving/models/preprocess/1/model.py +++ b/examples/vision/detection/yolov5/serving/models/preprocess/1/model.py @@ -95,7 +95,7 @@ class TritonPythonModel: dlpack_tensor) output_tensor_1 = pb_utils.Tensor( self.output_names[1], np.array( - im_infos, dtype=np.object)) + im_infos, dtype=np.object_)) inference_response = pb_utils.InferenceResponse( output_tensors=[output_tensor_0, output_tensor_1]) responses.append(inference_response) diff --git a/examples/vision/ocr/PP-OCRv3/serving/models/det_postprocess/1/model.py b/examples/vision/ocr/PP-OCRv3/serving/models/det_postprocess/1/model.py index faaca9067..9cfe2583e 100644 --- a/examples/vision/ocr/PP-OCRv3/serving/models/det_postprocess/1/model.py +++ b/examples/vision/ocr/PP-OCRv3/serving/models/det_postprocess/1/model.py @@ -217,7 +217,7 @@ class TritonPythonModel: out_tensor_0 = pb_utils.Tensor( self.output_names[0], np.array( - batch_rec_texts, dtype=np.object)) + batch_rec_texts, dtype=np.object_)) out_tensor_1 = pb_utils.Tensor(self.output_names[1], np.array(batch_rec_scores)) inference_response = pb_utils.InferenceResponse( diff --git a/examples/vision/ocr/PP-OCRv3/serving/models/rec_postprocess/1/model.py b/examples/vision/ocr/PP-OCRv3/serving/models/rec_postprocess/1/model.py index fe66e8c3f..c046cd929 100755 --- a/examples/vision/ocr/PP-OCRv3/serving/models/rec_postprocess/1/model.py +++ b/examples/vision/ocr/PP-OCRv3/serving/models/rec_postprocess/1/model.py @@ -96,7 +96,7 @@ class TritonPythonModel: results = self.postprocessor.run([infer_outputs]) out_tensor_0 = pb_utils.Tensor( self.output_names[0], np.array( - results[0], dtype=np.object)) + results[0], dtype=np.object_)) out_tensor_1 = pb_utils.Tensor(self.output_names[1], np.array(results[1])) inference_response = pb_utils.InferenceResponse( diff --git a/examples/vision/segmentation/paddleseg/rknpu2/cpp/infer.cc b/examples/vision/segmentation/paddleseg/rknpu2/cpp/infer.cc index 834b2ccb3..f80d3fc8f 100644 --- a/examples/vision/segmentation/paddleseg/rknpu2/cpp/infer.cc +++ b/examples/vision/segmentation/paddleseg/rknpu2/cpp/infer.cc @@ -62,7 +62,8 @@ void RKNPU2Infer(const std::string& model_dir, const std::string& image_file) { std::cerr << "Failed to initialize." << std::endl; return; } - model.GetPreprocessor().DisableNormalizeAndPermute(); + model.GetPreprocessor().DisablePermute(); + model.GetPreprocessor().DisableNormalize(); fastdeploy::TimeCounter tc; tc.Start(); diff --git a/examples/vision/segmentation/paddleseg/rknpu2/python/infer.py b/examples/vision/segmentation/paddleseg/rknpu2/python/infer.py index 4168d591d..193a6dfb9 100644 --- a/examples/vision/segmentation/paddleseg/rknpu2/python/infer.py +++ b/examples/vision/segmentation/paddleseg/rknpu2/python/infer.py @@ -49,7 +49,8 @@ model = fd.vision.segmentation.PaddleSegModel( runtime_option=runtime_option, model_format=fd.ModelFormat.RKNN) -model.preprocessor.disable_normalize_and_permute() +model.preprocessor.disable_normalize() +model.preprocessor.disable_permute() # 预测图片分割结果 im = cv2.imread(args.image) diff --git a/fastdeploy/vision/detection/ppdet/model.h b/fastdeploy/vision/detection/ppdet/model.h index be13b0b4d..17502cf21 100755 --- a/fastdeploy/vision/detection/ppdet/model.h +++ b/fastdeploy/vision/detection/ppdet/model.h @@ -68,6 +68,7 @@ class FASTDEPLOY_DECL PPYOLOE : public PPDetBase { valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT}; valid_timvx_backends = {Backend::LITE}; valid_kunlunxin_backends = {Backend::LITE}; + valid_rknpu_backends = {Backend::RKNPU2}; valid_ascend_backends = {Backend::LITE}; initialized = Initialize(); } diff --git a/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc b/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc index 09c89dfce..573164910 100644 --- a/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc +++ b/fastdeploy/vision/detection/ppdet/ppdet_pybind.cc @@ -31,7 +31,13 @@ void BindPPDet(pybind11::module& m) { outputs[i].StopSharing(); } return outputs; - }); + }) + .def("disable_normalize", [](vision::detection::PaddleDetPreprocessor& self) { + self.DisableNormalize(); + }) + .def("disable_permute", [](vision::detection::PaddleDetPreprocessor& self) { + self.DisablePermute(); + });; pybind11::class_( m, "PaddleDetPostprocessor") diff --git a/fastdeploy/vision/detection/ppdet/preprocessor.cc b/fastdeploy/vision/detection/ppdet/preprocessor.cc index bb38c67ec..a18d43b70 100644 --- a/fastdeploy/vision/detection/ppdet/preprocessor.cc +++ b/fastdeploy/vision/detection/ppdet/preprocessor.cc @@ -22,19 +22,19 @@ namespace vision { namespace detection { PaddleDetPreprocessor::PaddleDetPreprocessor(const std::string& config_file) { - FDASSERT(BuildPreprocessPipelineFromConfig(config_file), + this->config_file_ = config_file; + FDASSERT(BuildPreprocessPipelineFromConfig(), "Failed to create PaddleDetPreprocessor."); initialized_ = true; } -bool PaddleDetPreprocessor::BuildPreprocessPipelineFromConfig( - const std::string& config_file) { +bool PaddleDetPreprocessor::BuildPreprocessPipelineFromConfig() { processors_.clear(); YAML::Node cfg; try { - cfg = YAML::LoadFile(config_file); + cfg = YAML::LoadFile(config_file_); } catch (YAML::BadFile& e) { - FDERROR << "Failed to load yaml file " << config_file + FDERROR << "Failed to load yaml file " << config_file_ << ", maybe you should check this file." << std::endl; return false; } @@ -45,21 +45,23 @@ bool PaddleDetPreprocessor::BuildPreprocessPipelineFromConfig( for (const auto& op : cfg["Preprocess"]) { std::string op_name = op["type"].as(); if (op_name == "NormalizeImage") { - auto mean = op["mean"].as>(); - auto std = op["std"].as>(); - bool is_scale = true; - if (op["is_scale"]) { - is_scale = op["is_scale"].as(); + if (!disable_normalize_) { + auto mean = op["mean"].as>(); + auto std = op["std"].as>(); + bool is_scale = true; + if (op["is_scale"]) { + is_scale = op["is_scale"].as(); + } + std::string norm_type = "mean_std"; + if (op["norm_type"]) { + norm_type = op["norm_type"].as(); + } + if (norm_type != "mean_std") { + std::fill(mean.begin(), mean.end(), 0.0); + std::fill(std.begin(), std.end(), 1.0); + } + processors_.push_back(std::make_shared(mean, std, is_scale)); } - std::string norm_type = "mean_std"; - if (op["norm_type"]) { - norm_type = op["norm_type"].as(); - } - if (norm_type != "mean_std") { - std::fill(mean.begin(), mean.end(), 0.0); - std::fill(std.begin(), std.end(), 1.0); - } - processors_.push_back(std::make_shared(mean, std, is_scale)); } else if (op_name == "Resize") { bool keep_ratio = op["keep_ratio"].as(); auto target_size = op["target_size"].as>(); @@ -104,10 +106,12 @@ bool PaddleDetPreprocessor::BuildPreprocessPipelineFromConfig( return false; } } - if (has_permute) { - // permute = cast + HWC2CHW - processors_.push_back(std::make_shared("float")); - processors_.push_back(std::make_shared()); + if (!disable_permute_) { + if (has_permute) { + // permute = cast + HWC2CHW + processors_.push_back(std::make_shared("float")); + processors_.push_back(std::make_shared()); + } } // Fusion will improve performance @@ -202,7 +206,20 @@ bool PaddleDetPreprocessor::Run(std::vector* images, return true; } - +void PaddleDetPreprocessor::DisableNormalize() { + this->disable_normalize_ = true; + // the DisableNormalize function will be invalid if the configuration file is loaded during preprocessing + if (!BuildPreprocessPipelineFromConfig()) { + FDERROR << "Failed to build preprocess pipeline from configuration file." << std::endl; + } +} +void PaddleDetPreprocessor::DisablePermute() { + this->disable_permute_ = true; + // the DisablePermute function will be invalid if the configuration file is loaded during preprocessing + if (!BuildPreprocessPipelineFromConfig()) { + FDERROR << "Failed to build preprocess pipeline from configuration file." << std::endl; + } +} } // namespace detection } // namespace vision } // namespace fastdeploy diff --git a/fastdeploy/vision/detection/ppdet/preprocessor.h b/fastdeploy/vision/detection/ppdet/preprocessor.h index 2733c450e..8371afb69 100644 --- a/fastdeploy/vision/detection/ppdet/preprocessor.h +++ b/fastdeploy/vision/detection/ppdet/preprocessor.h @@ -39,10 +39,21 @@ class FASTDEPLOY_DECL PaddleDetPreprocessor { */ bool Run(std::vector* images, std::vector* outputs); + /// This function will disable normalize in preprocessing step. + void DisableNormalize(); + /// This function will disable hwc2chw in preprocessing step. + void DisablePermute(); + private: - bool BuildPreprocessPipelineFromConfig(const std::string& config_file); + bool BuildPreprocessPipelineFromConfig(); std::vector> processors_; bool initialized_ = false; + // for recording the switch of hwc2chw + bool disable_permute_ = false; + // for recording the switch of normalize + bool disable_normalize_ = false; + // read config file + std::string config_file_; }; } // namespace detection diff --git a/fastdeploy/vision/segmentation/ppseg/ppseg_pybind.cc b/fastdeploy/vision/segmentation/ppseg/ppseg_pybind.cc index e687d3cc4..78c7c9ccc 100644 --- a/fastdeploy/vision/segmentation/ppseg/ppseg_pybind.cc +++ b/fastdeploy/vision/segmentation/ppseg/ppseg_pybind.cc @@ -36,9 +36,12 @@ void BindPPSeg(pybind11::module& m) { } return make_pair(outputs, imgs_info);; }) - .def("disable_normalize_and_permute", - &vision::segmentation::PaddleSegPreprocessor::DisableNormalizeAndPermute) - + .def("disable_normalize", [](vision::segmentation::PaddleSegPreprocessor& self) { + self.DisableNormalize(); + }) + .def("disable_permute", [](vision::segmentation::PaddleSegPreprocessor& self) { + self.DisablePermute(); + }) .def_property("is_vertical_screen", &vision::segmentation::PaddleSegPreprocessor::GetIsVerticalScreen, &vision::segmentation::PaddleSegPreprocessor::SetIsVerticalScreen); diff --git a/fastdeploy/vision/segmentation/ppseg/preprocessor.cc b/fastdeploy/vision/segmentation/ppseg/preprocessor.cc index 027309aad..92b037895 100644 --- a/fastdeploy/vision/segmentation/ppseg/preprocessor.cc +++ b/fastdeploy/vision/segmentation/ppseg/preprocessor.cc @@ -43,7 +43,7 @@ bool PaddleSegPreprocessor::BuildPreprocessPipelineFromConfig() { FDASSERT(op.IsMap(), "Require the transform information in yaml be Map type."); if (op["type"].as() == "Normalize") { - if (!disable_normalize_and_permute_) { + if (!disable_normalize_) { std::vector mean = {0.5, 0.5, 0.5}; std::vector std = {0.5, 0.5, 0.5}; if (op["mean"]) { @@ -55,7 +55,7 @@ bool PaddleSegPreprocessor::BuildPreprocessPipelineFromConfig() { processors_.push_back(std::make_shared(mean, std)); } } else if (op["type"].as() == "Resize") { - is_contain_resize_op = true; + is_contain_resize_op_ = true; const auto& target_size = op["target_size"]; int resize_width = target_size[0].as(); int resize_height = target_size[1].as(); @@ -73,13 +73,13 @@ bool PaddleSegPreprocessor::BuildPreprocessPipelineFromConfig() { auto input_shape = cfg["Deploy"]["input_shape"]; int input_height = input_shape[2].as(); int input_width = input_shape[3].as(); - if (input_height != -1 && input_width != -1 && !is_contain_resize_op) { - is_contain_resize_op = true; + if (input_height != -1 && input_width != -1 && !is_contain_resize_op_) { + is_contain_resize_op_ = true; processors_.insert(processors_.begin(), std::make_shared(input_width, input_height)); } } - if (!disable_normalize_and_permute_) { + if (!disable_permute_) { processors_.push_back(std::make_shared()); } @@ -121,7 +121,7 @@ bool PaddleSegPreprocessor::Run(std::vector* images, std::vectorsize(); // Batch preprocess : resize all images to the largest image shape in batch - if (!is_contain_resize_op && img_num > 1) { + if (!is_contain_resize_op_ && img_num > 1) { int max_width = 0; int max_height = 0; for (size_t i = 0; i < img_num; ++i) { @@ -156,14 +156,20 @@ bool PaddleSegPreprocessor::Run(std::vector* images, std::vectordisable_normalize_ = true; + // the DisableNormalize function will be invalid if the configuration file is loaded during preprocessing + if (!BuildPreprocessPipelineFromConfig()) { + FDERROR << "Failed to build preprocess pipeline from configuration file." << std::endl; + } +} +void PaddleSegPreprocessor::DisablePermute() { + this->disable_permute_ = true; + // the DisablePermute function will be invalid if the configuration file is loaded during preprocessing if (!BuildPreprocessPipelineFromConfig()) { FDERROR << "Failed to build preprocess pipeline from configuration file." << std::endl; } } - } // namespace segmentation } // namespace vision } // namespace fastdeploy diff --git a/fastdeploy/vision/segmentation/ppseg/preprocessor.h b/fastdeploy/vision/segmentation/ppseg/preprocessor.h index faa7fb8de..6452e8e0e 100644 --- a/fastdeploy/vision/segmentation/ppseg/preprocessor.h +++ b/fastdeploy/vision/segmentation/ppseg/preprocessor.h @@ -49,8 +49,10 @@ class FASTDEPLOY_DECL PaddleSegPreprocessor { is_vertical_screen_ = value; } - // This function will disable normalize and hwc2chw in preprocessing step. - void DisableNormalizeAndPermute(); + /// This function will disable normalize in preprocessing step. + void DisableNormalize(); + /// This function will disable hwc2chw in preprocessing step. + void DisablePermute(); private: virtual bool BuildPreprocessPipelineFromConfig(); @@ -61,10 +63,12 @@ class FASTDEPLOY_DECL PaddleSegPreprocessor { */ bool is_vertical_screen_ = false; - // for recording the switch of normalize and hwc2chw - bool disable_normalize_and_permute_ = false; + // for recording the switch of hwc2chw + bool disable_permute_ = false; + // for recording the switch of normalize + bool disable_normalize_ = false; - bool is_contain_resize_op = false; + bool is_contain_resize_op_ = false; bool initialized_ = false; }; diff --git a/python/fastdeploy/vision/detection/ppdet/__init__.py b/python/fastdeploy/vision/detection/ppdet/__init__.py index 45734eef0..f9b162aca 100644 --- a/python/fastdeploy/vision/detection/ppdet/__init__.py +++ b/python/fastdeploy/vision/detection/ppdet/__init__.py @@ -36,6 +36,18 @@ class PaddleDetPreprocessor: """ return self._preprocessor.run(input_ims) + def disable_normalize(self): + """ + This function will disable normalize in preprocessing step. + """ + self._preprocessor.disable_normalize() + + def disable_permute(self): + """ + This function will disable hwc2chw in preprocessing step. + """ + self._preprocessor.disable_permute() + class PaddleDetPostprocessor: def __init__(self): @@ -500,4 +512,4 @@ class RTMDet(PPYOLOE): self._model = C.vision.detection.RTMDet( model_file, params_file, config_file, self._runtime_option, model_format) - assert self.initialized, "RTMDet model initialize failed." \ No newline at end of file + assert self.initialized, "RTMDet model initialize failed." diff --git a/python/fastdeploy/vision/segmentation/ppseg/__init__.py b/python/fastdeploy/vision/segmentation/ppseg/__init__.py index 455785686..f0106a39a 100644 --- a/python/fastdeploy/vision/segmentation/ppseg/__init__.py +++ b/python/fastdeploy/vision/segmentation/ppseg/__init__.py @@ -104,10 +104,17 @@ class PaddleSegPreprocessor: """ return self._preprocessor.run(input_ims) - def disable_normalize_and_permute(self): - """To disable normalize and hwc2chw in preprocessing step. + def disable_normalize(self): """ - return self._preprocessor.disable_normalize_and_permute() + This function will disable normalize in preprocessing step. + """ + self._preprocessor.disable_normalize() + + def disable_permute(self): + """ + This function will disable hwc2chw in preprocessing step. + """ + self._preprocessor.disable_permute() @property def is_vertical_screen(self):