mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-06 00:57:33 +08:00
[Backend] Add collect shape for pp-trt backend (#372)
* Add collect_shape attr * add EnableTunedTensorRtDynamicShape * Add collect shape python api * Fix quant model not set trt dynamic shape * Add shape info print * Fix shape print * Use CopyFromCpu instead of ShareExternalData * Add ENABLE_TRT_BACKEND macro * Add shared data with
This commit is contained in:
@@ -13,6 +13,8 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include "fastdeploy/backends/paddle/paddle_backend.h"
|
#include "fastdeploy/backends/paddle/paddle_backend.h"
|
||||||
|
#include "fastdeploy/utils/path.h"
|
||||||
|
#include <sstream>
|
||||||
|
|
||||||
namespace fastdeploy {
|
namespace fastdeploy {
|
||||||
|
|
||||||
@@ -31,21 +33,7 @@ void PaddleBackend::BuildOption(const PaddleBackendOption& option) {
|
|||||||
use_static = true;
|
use_static = true;
|
||||||
}
|
}
|
||||||
config_.EnableTensorRtEngine(option.trt_option.max_workspace_size, 32, 3, precision, use_static);
|
config_.EnableTensorRtEngine(option.trt_option.max_workspace_size, 32, 3, precision, use_static);
|
||||||
std::map<std::string, std::vector<int>> max_shape;
|
SetTRTDynamicShapeToConfig(option);
|
||||||
std::map<std::string, std::vector<int>> min_shape;
|
|
||||||
std::map<std::string, std::vector<int>> opt_shape;
|
|
||||||
for (const auto& item : option.trt_option.min_shape) {
|
|
||||||
auto max_iter = option.trt_option.max_shape.find(item.first);
|
|
||||||
auto opt_iter = option.trt_option.opt_shape.find(item.first);
|
|
||||||
FDASSERT(max_iter != option.trt_option.max_shape.end(), "Cannot find %s in TrtBackendOption::min_shape.", item.first.c_str());
|
|
||||||
FDASSERT(opt_iter != option.trt_option.opt_shape.end(), "Cannot find %s in TrtBackendOption::opt_shape.", item.first.c_str());
|
|
||||||
max_shape[item.first].assign(max_iter->second.begin(), max_iter->second.end());
|
|
||||||
opt_shape[item.first].assign(opt_iter->second.begin(), opt_iter->second.end());
|
|
||||||
min_shape[item.first].assign(item.second.begin(), item.second.end());
|
|
||||||
}
|
|
||||||
if (min_shape.size() > 0) {
|
|
||||||
config_.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape);
|
|
||||||
}
|
|
||||||
#else
|
#else
|
||||||
FDWARNING << "The FastDeploy is not compiled with TensorRT backend, so will fallback to GPU with Paddle Inference Backend." << std::endl;
|
FDWARNING << "The FastDeploy is not compiled with TensorRT backend, so will fallback to GPU with Paddle Inference Backend." << std::endl;
|
||||||
#endif
|
#endif
|
||||||
@@ -97,6 +85,17 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
|
|||||||
if (reader.is_quantize_model) {
|
if (reader.is_quantize_model) {
|
||||||
if (option.use_gpu) {
|
if (option.use_gpu) {
|
||||||
FDWARNING << "The loaded model is a quantized model, while inference on GPU, please use TensorRT backend to get better performance." << std::endl;
|
FDWARNING << "The loaded model is a quantized model, while inference on GPU, please use TensorRT backend to get better performance." << std::endl;
|
||||||
|
if (option.enable_trt) {
|
||||||
|
#ifdef ENABLE_TRT_BACKEND
|
||||||
|
bool use_static = false;
|
||||||
|
if (option.trt_option.serialize_file != "") {
|
||||||
|
FDWARNING << "Detect that tensorrt cache file has been set to " << option.trt_option.serialize_file << ", but while enable paddle2trt, please notice that the cache file will save to the directory where paddle model saved." << std::endl;
|
||||||
|
use_static = true;
|
||||||
|
}
|
||||||
|
config_.EnableTensorRtEngine(option.trt_option.max_workspace_size, 32, 3, paddle_infer::PrecisionType::kInt8, use_static, false);
|
||||||
|
SetTRTDynamicShapeToConfig(option);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if (option.enable_mkldnn) {
|
if (option.enable_mkldnn) {
|
||||||
config_.EnableMkldnnInt8();
|
config_.EnableMkldnnInt8();
|
||||||
@@ -123,7 +122,31 @@ bool PaddleBackend::InitFromPaddle(const std::string& model_file,
|
|||||||
outputs_desc_[i].shape.assign(shape.begin(), shape.end());
|
outputs_desc_[i].shape.assign(shape.begin(), shape.end());
|
||||||
outputs_desc_[i].dtype = ReaderDataTypeToFD(reader.outputs[i].dtype);
|
outputs_desc_[i].dtype = ReaderDataTypeToFD(reader.outputs[i].dtype);
|
||||||
}
|
}
|
||||||
|
#ifdef ENABLE_TRT_BACKEND
|
||||||
|
if (option.collect_shape) {
|
||||||
|
// Set the shape info file.
|
||||||
|
auto curr_model_dir = GetDirFromPath(model_file);
|
||||||
|
std::string shape_range_info = PathJoin(curr_model_dir, "shape_range_info.pbtxt");
|
||||||
|
if (!CheckFileExists(shape_range_info)) {
|
||||||
|
FDINFO << "Start generating shape range info file." << std::endl;
|
||||||
|
paddle_infer::Config analysis_config;
|
||||||
|
analysis_config.SetModel(model_file, params_file);
|
||||||
|
analysis_config.CollectShapeRangeInfo(shape_range_info);
|
||||||
|
auto predictor_tmp = paddle_infer::CreatePredictor(analysis_config);
|
||||||
|
std::map<std::string, std::vector<int>> max_shape;
|
||||||
|
std::map<std::string, std::vector<int>> min_shape;
|
||||||
|
std::map<std::string, std::vector<int>> opt_shape;
|
||||||
|
GetDynamicShapeFromOption(option, &max_shape, &min_shape, &opt_shape);
|
||||||
|
// Need to run once to get the shape range info file.
|
||||||
|
CollectShapeRun(predictor_tmp.get(), max_shape);
|
||||||
|
CollectShapeRun(predictor_tmp.get(), min_shape);
|
||||||
|
CollectShapeRun(predictor_tmp.get(), opt_shape);
|
||||||
|
FDINFO << "Finish generating shape range info file." << std::endl;
|
||||||
|
}
|
||||||
|
FDINFO << "Start loading shape range info file "<< shape_range_info << " to set TensorRT dynamic shape." << std::endl;
|
||||||
|
config_.EnableTunedTensorRtDynamicShape(shape_range_info, false);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
predictor_ = paddle_infer::CreatePredictor(config_);
|
predictor_ = paddle_infer::CreatePredictor(config_);
|
||||||
initialized_ = true;
|
initialized_ = true;
|
||||||
return true;
|
return true;
|
||||||
@@ -172,4 +195,87 @@ bool PaddleBackend::Infer(std::vector<FDTensor>& inputs,
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#ifdef ENABLE_TRT_BACKEND
|
||||||
|
void PaddleBackend::SetTRTDynamicShapeToConfig(const PaddleBackendOption& option) {
|
||||||
|
std::map<std::string, std::vector<int>> max_shape;
|
||||||
|
std::map<std::string, std::vector<int>> min_shape;
|
||||||
|
std::map<std::string, std::vector<int>> opt_shape;
|
||||||
|
GetDynamicShapeFromOption(option, &max_shape, &min_shape, &opt_shape);
|
||||||
|
FDINFO << "Start setting trt dynamic shape." << std::endl;
|
||||||
|
if (min_shape.size() > 0) {
|
||||||
|
config_.SetTRTDynamicShapeInfo(min_shape, max_shape, opt_shape);
|
||||||
|
}
|
||||||
|
FDINFO << "Finish setting trt dynamic shape." << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
void PaddleBackend::GetDynamicShapeFromOption(const PaddleBackendOption& option,
|
||||||
|
std::map<std::string, std::vector<int>>* max_shape,
|
||||||
|
std::map<std::string, std::vector<int>>* min_shape,
|
||||||
|
std::map<std::string, std::vector<int>>* opt_shape) const {
|
||||||
|
auto print_shape = [](const std::vector<int>& shape) -> std::string {
|
||||||
|
std::ostringstream oss;
|
||||||
|
oss << "[";
|
||||||
|
for (int i = 0; i < shape.size(); ++i) {
|
||||||
|
oss << shape[i];
|
||||||
|
if (i < shape.size() - 1) {
|
||||||
|
oss << ", ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
oss << "]";
|
||||||
|
return oss.str();
|
||||||
|
};
|
||||||
|
for (const auto& item : option.trt_option.min_shape) {
|
||||||
|
auto max_iter = option.trt_option.max_shape.find(item.first);
|
||||||
|
auto opt_iter = option.trt_option.opt_shape.find(item.first);
|
||||||
|
FDASSERT(max_iter != option.trt_option.max_shape.end(), "Cannot find %s in TrtBackendOption::min_shape.", item.first.c_str());
|
||||||
|
FDASSERT(opt_iter != option.trt_option.opt_shape.end(), "Cannot find %s in TrtBackendOption::opt_shape.", item.first.c_str());
|
||||||
|
(*max_shape)[item.first].assign(max_iter->second.begin(), max_iter->second.end());
|
||||||
|
(*opt_shape)[item.first].assign(opt_iter->second.begin(), opt_iter->second.end());
|
||||||
|
(*min_shape)[item.first].assign(item.second.begin(), item.second.end());
|
||||||
|
FDINFO << item.first << ": the max shape = " << print_shape(max_iter->second)
|
||||||
|
<< ", the min shape = " << print_shape(item.second)
|
||||||
|
<< ", the opt shape = " << print_shape(opt_iter->second) << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void PaddleBackend::CollectShapeRun(paddle_infer::Predictor* predictor,
|
||||||
|
const std::map<std::string, std::vector<int>>& shape) const {
|
||||||
|
auto input_names = predictor->GetInputNames();
|
||||||
|
auto input_type = predictor->GetInputTypes();
|
||||||
|
for(auto name : input_names) {
|
||||||
|
FDASSERT(shape.find(name) != shape.end() && input_type.find(name) != input_type.end(),
|
||||||
|
"Paddle Input name [%s] is not one of the trt dynamic shape.", name.c_str());
|
||||||
|
auto tensor = predictor->GetInputHandle(name);
|
||||||
|
auto shape_value = shape.at(name);
|
||||||
|
int shape_num = std::accumulate(shape_value.begin(), shape_value.end(), 1,
|
||||||
|
std::multiplies<int>());
|
||||||
|
tensor->Reshape(shape_value);
|
||||||
|
auto dtype = input_type[name];
|
||||||
|
switch (dtype) {
|
||||||
|
case paddle_infer::DataType::FLOAT32: {
|
||||||
|
std::vector<float> input_data(shape_num, 1.0);
|
||||||
|
tensor->CopyFromCpu(input_data.data());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case paddle_infer::DataType::INT32: {
|
||||||
|
std::vector<int> input_data(shape_num, 1);
|
||||||
|
tensor->CopyFromCpu(input_data.data());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case paddle_infer::DataType::INT64: {
|
||||||
|
std::vector<int64_t> input_data(shape_num, 1);
|
||||||
|
tensor->CopyFromCpu(input_data.data());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
default: {
|
||||||
|
FDASSERT(false, "Input data Paddle backend only supports FP32/INT32/INT64 currently.");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
predictor->Run();
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
} // namespace fastdeploy
|
} // namespace fastdeploy
|
||||||
|
@@ -44,6 +44,7 @@ struct PaddleBackendOption {
|
|||||||
bool enable_trt = false;
|
bool enable_trt = false;
|
||||||
#ifdef ENABLE_TRT_BACKEND
|
#ifdef ENABLE_TRT_BACKEND
|
||||||
TrtBackendOption trt_option;
|
TrtBackendOption trt_option;
|
||||||
|
bool collect_shape = false;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
int mkldnn_cache_size = 1;
|
int mkldnn_cache_size = 1;
|
||||||
@@ -95,6 +96,15 @@ class PaddleBackend : public BaseBackend {
|
|||||||
std::vector<TensorInfo> GetOutputInfos() override;
|
std::vector<TensorInfo> GetOutputInfos() override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
#ifdef ENABLE_TRT_BACKEND
|
||||||
|
void CollectShapeRun(paddle_infer::Predictor* predictor,
|
||||||
|
const std::map<std::string, std::vector<int>>& shape) const;
|
||||||
|
void GetDynamicShapeFromOption(const PaddleBackendOption& option,
|
||||||
|
std::map<std::string, std::vector<int>>* max_shape,
|
||||||
|
std::map<std::string, std::vector<int>>* min_shape,
|
||||||
|
std::map<std::string, std::vector<int>>* opt_shape) const;
|
||||||
|
void SetTRTDynamicShapeToConfig(const PaddleBackendOption& option);
|
||||||
|
#endif
|
||||||
paddle_infer::Config config_;
|
paddle_infer::Config config_;
|
||||||
std::shared_ptr<paddle_infer::Predictor> predictor_;
|
std::shared_ptr<paddle_infer::Predictor> predictor_;
|
||||||
std::vector<TensorInfo> inputs_desc_;
|
std::vector<TensorInfo> inputs_desc_;
|
||||||
|
@@ -29,16 +29,28 @@ void ShareTensorFromFDTensor(paddle_infer::Tensor* tensor,
|
|||||||
tensor->Reshape(shape);
|
tensor->Reshape(shape);
|
||||||
auto place = ConvertFDDeviceToPlace(fd_tensor.device);
|
auto place = ConvertFDDeviceToPlace(fd_tensor.device);
|
||||||
if (fd_tensor.dtype == FDDataType::FP32) {
|
if (fd_tensor.dtype == FDDataType::FP32) {
|
||||||
tensor->ShareExternalData(static_cast<const float*>(fd_tensor.Data()),
|
if (place == paddle_infer::PlaceType::kGPU) {
|
||||||
|
tensor->ShareExternalData(static_cast<const float*>(fd_tensor.Data()),
|
||||||
shape, place);
|
shape, place);
|
||||||
|
} else {
|
||||||
|
tensor->CopyFromCpu(static_cast<const float*>(fd_tensor.Data()));
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
} else if (fd_tensor.dtype == FDDataType::INT32) {
|
} else if (fd_tensor.dtype == FDDataType::INT32) {
|
||||||
tensor->ShareExternalData(static_cast<const int32_t*>(fd_tensor.Data()),
|
if (place == paddle_infer::PlaceType::kGPU) {
|
||||||
|
tensor->ShareExternalData(static_cast<const int32_t*>(fd_tensor.Data()),
|
||||||
shape, place);
|
shape, place);
|
||||||
|
} else {
|
||||||
|
tensor->CopyFromCpu(static_cast<const int32_t*>(fd_tensor.Data()));
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
} else if (fd_tensor.dtype == FDDataType::INT64) {
|
} else if (fd_tensor.dtype == FDDataType::INT64) {
|
||||||
tensor->ShareExternalData(static_cast<const int64_t*>(fd_tensor.Data()),
|
if (place == paddle_infer::PlaceType::kGPU) {
|
||||||
|
tensor->ShareExternalData(static_cast<const int64_t*>(fd_tensor.Data()),
|
||||||
shape, place);
|
shape, place);
|
||||||
|
} else {
|
||||||
|
tensor->CopyFromCpu(static_cast<const int64_t*>(fd_tensor.Data()));
|
||||||
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
|
FDASSERT(false, "Unexpected data type(%s) while infer with PaddleBackend.",
|
||||||
|
@@ -44,6 +44,8 @@ void BindRuntime(pybind11::module& m) {
|
|||||||
.def("enable_trt_fp16", &RuntimeOption::EnableTrtFP16)
|
.def("enable_trt_fp16", &RuntimeOption::EnableTrtFP16)
|
||||||
.def("disable_trt_fp16", &RuntimeOption::DisableTrtFP16)
|
.def("disable_trt_fp16", &RuntimeOption::DisableTrtFP16)
|
||||||
.def("set_trt_cache_file", &RuntimeOption::SetTrtCacheFile)
|
.def("set_trt_cache_file", &RuntimeOption::SetTrtCacheFile)
|
||||||
|
.def("enable_paddle_trt_collect_shape", &RuntimeOption::EnablePaddleTrtCollectShape)
|
||||||
|
.def("disable_paddle_trt_collect_shape", &RuntimeOption::DisablePaddleTrtCollectShape)
|
||||||
.def_readwrite("model_file", &RuntimeOption::model_file)
|
.def_readwrite("model_file", &RuntimeOption::model_file)
|
||||||
.def_readwrite("params_file", &RuntimeOption::params_file)
|
.def_readwrite("params_file", &RuntimeOption::params_file)
|
||||||
.def_readwrite("model_format", &RuntimeOption::model_format)
|
.def_readwrite("model_format", &RuntimeOption::model_format)
|
||||||
|
@@ -390,6 +390,14 @@ bool Runtime::Compile(std::vector<std::vector<FDTensor>>& prewarm_tensors,
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void RuntimeOption::EnablePaddleTrtCollectShape() {
|
||||||
|
pd_collect_shape = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void RuntimeOption::DisablePaddleTrtCollectShape() {
|
||||||
|
pd_collect_shape = false;
|
||||||
|
}
|
||||||
|
|
||||||
bool Runtime::Init(const RuntimeOption& _option) {
|
bool Runtime::Init(const RuntimeOption& _option) {
|
||||||
option = _option;
|
option = _option;
|
||||||
if (option.model_format == ModelFormat::AUTOREC) {
|
if (option.model_format == ModelFormat::AUTOREC) {
|
||||||
@@ -498,6 +506,7 @@ void Runtime::CreatePaddleBackend() {
|
|||||||
#ifdef ENABLE_TRT_BACKEND
|
#ifdef ENABLE_TRT_BACKEND
|
||||||
if (pd_option.use_gpu && option.pd_enable_trt) {
|
if (pd_option.use_gpu && option.pd_enable_trt) {
|
||||||
pd_option.enable_trt = true;
|
pd_option.enable_trt = true;
|
||||||
|
pd_option.collect_shape = option.pd_collect_shape;
|
||||||
auto trt_option = TrtBackendOption();
|
auto trt_option = TrtBackendOption();
|
||||||
trt_option.gpu_id = option.device_id;
|
trt_option.gpu_id = option.device_id;
|
||||||
trt_option.enable_fp16 = option.trt_enable_fp16;
|
trt_option.enable_fp16 = option.trt_enable_fp16;
|
||||||
|
@@ -204,6 +204,17 @@ struct FASTDEPLOY_DECL RuntimeOption {
|
|||||||
*/
|
*/
|
||||||
void SetTrtCacheFile(const std::string& cache_file_path);
|
void SetTrtCacheFile(const std::string& cache_file_path);
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Enable to collect shape in paddle trt backend
|
||||||
|
*/
|
||||||
|
void EnablePaddleTrtCollectShape();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Disable to collect shape in paddle trt backend
|
||||||
|
*/
|
||||||
|
void DisablePaddleTrtCollectShape();
|
||||||
|
|
||||||
Backend backend = Backend::UNKNOWN;
|
Backend backend = Backend::UNKNOWN;
|
||||||
// for cpu inference and preprocess
|
// for cpu inference and preprocess
|
||||||
// default will let the backend choose their own default value
|
// default will let the backend choose their own default value
|
||||||
@@ -225,6 +236,7 @@ struct FASTDEPLOY_DECL RuntimeOption {
|
|||||||
bool pd_enable_mkldnn = true;
|
bool pd_enable_mkldnn = true;
|
||||||
bool pd_enable_log_info = false;
|
bool pd_enable_log_info = false;
|
||||||
bool pd_enable_trt = false;
|
bool pd_enable_trt = false;
|
||||||
|
bool pd_collect_shape = false;
|
||||||
int pd_mkldnn_cache_size = 1;
|
int pd_mkldnn_cache_size = 1;
|
||||||
std::vector<std::string> pd_delete_pass_names;
|
std::vector<std::string> pd_delete_pass_names;
|
||||||
|
|
||||||
|
74
fastdeploy/utils/path.h
Normal file
74
fastdeploy/utils/path.h
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <fstream>
|
||||||
|
#ifdef _MSC_VER
|
||||||
|
#define PATH_SEP "\\"
|
||||||
|
#else
|
||||||
|
#define PATH_SEP "/"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace fastdeploy {
|
||||||
|
|
||||||
|
inline std::string PathJoin(const std::vector<std::string>& paths,
|
||||||
|
const std::string& sep = PATH_SEP) {
|
||||||
|
if (paths.size() == 1) {
|
||||||
|
return paths[0];
|
||||||
|
}
|
||||||
|
std::string filepath = "";
|
||||||
|
for (const auto& path : paths) {
|
||||||
|
if (filepath == "") {
|
||||||
|
filepath += path;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (path[0] == sep[0] || filepath.back() == sep[0]) {
|
||||||
|
filepath += path;
|
||||||
|
} else {
|
||||||
|
filepath += sep + path;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filepath;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::string PathJoin(const std::string& folder,
|
||||||
|
const std::string& filename,
|
||||||
|
const std::string& sep = PATH_SEP) {
|
||||||
|
return PathJoin(std::vector<std::string>{folder, filename}, sep);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline std::string GetDirFromPath(const std::string& path) {
|
||||||
|
auto pos = path.find_last_of(PATH_SEP);
|
||||||
|
if (pos == std::string::npos) {
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
// The root path in UNIX systems
|
||||||
|
if (pos == 0) {
|
||||||
|
return "/";
|
||||||
|
}
|
||||||
|
return path.substr(0, pos);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool CheckFileExists(const std::string& path) {
|
||||||
|
std::fstream fin(path, std::ios::in);
|
||||||
|
if (!fin) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace fastdeploy
|
@@ -329,6 +329,12 @@ class RuntimeOption:
|
|||||||
"""
|
"""
|
||||||
return self._option.set_trt_max_workspace_size(trt_max_workspace_size)
|
return self._option.set_trt_max_workspace_size(trt_max_workspace_size)
|
||||||
|
|
||||||
|
def enable_paddle_trt_collect_shape(self):
|
||||||
|
return self._option.enable_paddle_trt_collect_shape()
|
||||||
|
|
||||||
|
def disable_paddle_trt_collect_shape(self):
|
||||||
|
return self._option.disable_paddle_trt_collect_shape()
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
attrs = dir(self._option)
|
attrs = dir(self._option)
|
||||||
message = "RuntimeOption(\n"
|
message = "RuntimeOption(\n"
|
||||||
|
@@ -26,7 +26,7 @@ def process_paddle_inference(paddle_inference_so_file):
|
|||||||
rpaths = [
|
rpaths = [
|
||||||
"$ORIGIN", "$ORIGIN/../../third_party/install/mkldnn/lib/",
|
"$ORIGIN", "$ORIGIN/../../third_party/install/mkldnn/lib/",
|
||||||
"$ORIGIN/../../third_party/install/mklml/lib/",
|
"$ORIGIN/../../third_party/install/mklml/lib/",
|
||||||
"$ORIGIN/../../../tensorrt/lib"
|
"$ORIGIN/../../../tensorrt/lib/"
|
||||||
]
|
]
|
||||||
|
|
||||||
patchelf_exe = os.getenv("PATCHELF_EXE", "patchelf")
|
patchelf_exe = os.getenv("PATCHELF_EXE", "patchelf")
|
||||||
|
Reference in New Issue
Block a user