[Bug] Fix build error (#2112)

Fix paddle2onnx build error.

Author: Jason
Date: 2023-07-16 19:49:50 -07:00
Committed by: GitHub
Parent: 681ccc4c24
Commit: f413e0263b

32 changed files with 597 additions and 176 deletions

@@ -129,6 +129,12 @@ def arg_parser():
         type=ast.literal_eval,
         default=False,
         help="Whether export FP16 model for ORT-GPU, default False")
+    parser.add_argument(
+        "--custom_ops",
+        type=_text_type,
+        default="{}",
+        help="Ops that needs to be converted to custom op, e.g --custom_ops '{\"paddle_op\":\"onnx_op\"}', default {}"
+    )
     return parser
@@ -144,12 +150,14 @@ def c_paddle_to_onnx(model_file,
                      deploy_backend="onnxruntime",
                      calibration_file="",
                      external_file="",
-                     export_fp16_model=False):
+                     export_fp16_model=False,
+                     custom_ops={}):
     import paddle2onnx.paddle2onnx_cpp2py_export as c_p2o
     onnx_model_str = c_p2o.export(
         model_file, params_file, opset_version, auto_upgrade_opset, verbose,
-        enable_onnx_checker, enable_experimental_op, enable_optimize, {},
-        deploy_backend, calibration_file, external_file, export_fp16_model)
+        enable_onnx_checker, enable_experimental_op, enable_optimize,
+        custom_ops, deploy_backend, calibration_file, external_file,
+        export_fp16_model)
     if save_file is not None:
         with open(save_file, "wb") as f:
             f.write(onnx_model_str)
@@ -235,6 +243,8 @@ def main():
             os.mkdir(base_path)
         external_file = os.path.join(base_path, args.external_filename)
+    custom_ops_dict = eval(args.custom_ops)
     calibration_file = args.save_calibration_file
     c_paddle_to_onnx(
         model_file=model_file,
@@ -249,7 +259,8 @@ def main():
         deploy_backend=args.deploy_backend,
         calibration_file=calibration_file,
         external_file=external_file,
-        export_fp16_model=args.export_fp16_model)
+        export_fp16_model=args.export_fp16_model,
+        custom_ops=custom_ops_dict)
     logging.info("===============Make PaddlePaddle Better!================")
     logging.info("A little survey: https://iwenjuan.baidu.com/?code=r8hu2s")
     return
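The --custom_ops value is parsed with eval() into a dict mapping Paddle op names to ONNX custom op names. A minimal usage sketch (the op names here are placeholders, not from this commit):

    # paddle2onnx ... --custom_ops '{"paddle_op": "onnx_op"}'
    custom_ops_dict = eval('{"paddle_op": "onnx_op"}')  # same parsing as main()
    assert custom_ops_dict == {"paddle_op": "onnx_op"}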

@@ -137,7 +137,8 @@ PADDLE2ONNX_DECL bool Export(
     bool enable_onnx_checker, bool enable_experimental_op, bool enable_optimize,
     CustomOp* ops, int op_count, const char* deploy_backend,
     char** calibration_cache, int* calibration_size, const char* external_file,
-    bool* save_external, bool export_fp16_model) {
+    bool* save_external, bool export_fp16_model, char** disable_fp16_op_types,
+    int disable_fp16_op_types_count) {
   auto parser = PaddleParser();
   P2OLogger(verbose) << "Start to parsing Paddle model..." << std::endl;
   if (!parser.Init(model, params)) {
@@ -158,12 +159,20 @@ PADDLE2ONNX_DECL bool Export(
       me.custom_ops[op_name] = export_op_name;
     }
   }
+  // Add disabled fp16 op information
+  std::vector<std::string> disable_op_types;
+  if (disable_fp16_op_types != nullptr && disable_fp16_op_types_count > 0) {
+    for (int i = 0; i < disable_fp16_op_types_count; ++i) {
+      std::string disable_op_type(disable_fp16_op_types[i],
+                                  strlen(disable_fp16_op_types[i]));
+      disable_op_types.push_back(disable_op_type);
+    }
+  }
   std::string calibration_str;
   std::string result = me.Run(
       parser, opset_version, auto_upgrade_opset, verbose, enable_onnx_checker,
       enable_experimental_op, enable_optimize, deploy_backend, &calibration_str,
-      external_file, save_external, export_fp16_model);
+      external_file, save_external, export_fp16_model, disable_op_types);
   if (result.empty()) {
     P2OLogger(verbose) << "The exported ONNX model is invalid!" << std::endl;
     return false;
@@ -193,7 +202,8 @@ PADDLE2ONNX_DECL bool Export(
     bool enable_experimental_op, bool enable_optimize, CustomOp* ops,
     int op_count, const char* deploy_backend, char** calibration_cache,
     int* calibration_size, const char* external_file, bool* save_external,
-    bool export_fp16_model) {
+    bool export_fp16_model, char** disable_fp16_op_types,
+    int disable_fp16_op_types_count) {
   auto parser = PaddleParser();
   P2OLogger(verbose) << "Start to parsing Paddle model..." << std::endl;
   if (!parser.Init(model_buffer, model_size, params_buffer, params_size)) {
@@ -214,11 +224,20 @@ PADDLE2ONNX_DECL bool Export(
       me.custom_ops[op_name] = export_op_name;
     }
   }
+  // Add disabled fp16 op information
+  std::vector<std::string> disable_op_types;
+  if (disable_fp16_op_types != nullptr && disable_fp16_op_types_count > 0) {
+    for (int i = 0; i < disable_fp16_op_types_count; ++i) {
+      std::string disable_op_type(disable_fp16_op_types[i],
+                                  strlen(disable_fp16_op_types[i]));
+      disable_op_types.push_back(disable_op_type);
+    }
+  }
   std::string calibration_str;
   std::string result = me.Run(
       parser, opset_version, auto_upgrade_opset, verbose, enable_onnx_checker,
       enable_experimental_op, enable_optimize, deploy_backend, &calibration_str,
-      external_file, save_external, export_fp16_model);
+      external_file, save_external, export_fp16_model, disable_op_types);
   if (result.empty()) {
     P2OLogger(verbose) << "The exported ONNX model is invalid!" << std::endl;
     return false;
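The new disable_fp16_op_types parameter follows the usual C ABI convention of a char** plus an element count, copied into a std::vector<std::string> before use. A hypothetical caller-side sketch of that convention using Python ctypes (illustrative only, not the project's actual binding code):

    import ctypes

    disable_types = ["LayerNormalization", "Softmax"]  # placeholder op names
    arr = (ctypes.c_char_p * len(disable_types))(
        *[t.encode("utf-8") for t in disable_types])
    count = len(disable_types)
    # arr and count map to disable_fp16_op_types / disable_fp16_op_types_count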

paddle2onnx/converter.h (6 changes; Executable file → Normal file)

@@ -57,7 +57,8 @@ PADDLE2ONNX_DECL bool Export(
     const char* deploy_backend = "onnxruntime",
     char** calibration_cache = nullptr, int* calibration_size = 0,
     const char* external_file = "", bool* save_external = nullptr,
-    bool export_fp16_model = false);
+    bool export_fp16_model = false, char** disable_fp16_op_types = nullptr,
+    int disable_fp16_op_types_count = 0);
 
 PADDLE2ONNX_DECL bool Export(
     const void* model_buffer, int64_t model_size, const void* params_buffer,
@@ -68,7 +69,8 @@ PADDLE2ONNX_DECL bool Export(
     const char* deploy_backend = "onnxruntime",
     char** calibration_cache = nullptr, int* calibration_size = 0,
     const char* external_file = "", bool* save_external = nullptr,
-    bool export_fp16_model = false);
+    bool export_fp16_model = false, char** disable_fp16_op_types = nullptr,
+    int disable_fp16_op_types_count = 0);
 
 // Following are inside usage, will remove it maybe
 struct PADDLE2ONNX_DECL ModelTensorInfo {

@@ -186,7 +186,7 @@ void SwishMapper::Opset7() {
   auto output_info = GetOutput("Out");
   std::string beta_node =
-      helper_->Constant({1}, GetOnnxDtype(input_info[0].dtype), beta_);
+      helper_->Constant({}, GetOnnxDtype(input_info[0].dtype), beta_);
   // TODO(jiangjiajun) eliminate multiply with a constant of value 1
   // TODO(jiangjiajun) eliminate add with a constant of value 0
   auto beta_x_node = helper_->MakeNode("Mul", {input_info[0].name, beta_node});
@@ -200,9 +200,9 @@ void HardSwishMapper::Opset7() {
   auto output_info = GetOutput("Out");
   std::string scale_node =
-      helper_->Constant({1}, GetOnnxDtype(input_info[0].dtype), scale_);
+      helper_->Constant({}, GetOnnxDtype(input_info[0].dtype), scale_);
   std::string offset_node =
-      helper_->Constant({1}, GetOnnxDtype(input_info[0].dtype), offset_);
+      helper_->Constant({}, GetOnnxDtype(input_info[0].dtype), offset_);
   auto add_node = helper_->MakeNode("Add", {input_info[0].name, offset_node});
   auto clip_node =
@@ -239,11 +239,11 @@ void GeluMapper::Opset9() {
   double scale_value = 0.5;
   double const_1_value = 1.0;
   auto sqrt_2 =
-      helper_->Constant({1}, ONNX_NAMESPACE::TensorProto::FLOAT, sqrt_2_value);
+      helper_->Constant({}, ONNX_NAMESPACE::TensorProto::FLOAT, sqrt_2_value);
   auto scale =
-      helper_->Constant({1}, ONNX_NAMESPACE::TensorProto::FLOAT, scale_value);
+      helper_->Constant({}, ONNX_NAMESPACE::TensorProto::FLOAT, scale_value);
   auto const_1 =
-      helper_->Constant({1}, ONNX_NAMESPACE::TensorProto::FLOAT, const_1_value);
+      helper_->Constant({}, ONNX_NAMESPACE::TensorProto::FLOAT, const_1_value);
   auto input_name = helper_->AutoCast(input_info[0].name, input_info[0].dtype,
                                       P2ODataType::FP32);
@@ -268,26 +268,34 @@ void GeluMapper::Opset9() {
 void SoftMaxMapper::Opset7() {
   auto input_info = GetInput("X");
   auto output_info = GetOutput("Out");
-  if (axis_ < 0) {
-    axis_ = axis_ + output_info[0].Rank();
-  }
-  if (axis_ == output_info[0].Rank() - 1) {
-    auto node = helper_->MakeNode("Softmax", {input_info[0].name},
-                                  {output_info[0].name});
-    AddAttribute(node, "axis", axis_);
+  if (input_info[0].Rank() == 0) {
+    auto unsqueeze = helper_->Unsqueeze(input_info[0].name, {0});
+    auto node = helper_->MakeNode("Softmax", {unsqueeze});
+    AddAttribute(node, "axis", static_cast<int64_t>(0));
+    helper_->Squeeze(node->output(0), output_info[0].name, {0});
   } else {
-    std::vector<int64_t> perm = Arange(0, output_info[0].Rank());
-    perm[output_info[0].Rank() - 1] = axis_;
-    perm[axis_] = output_info[0].Rank() - 1;
-    auto transpose_node = helper_->MakeNode("Transpose", {input_info[0].name});
-    AddAttribute(transpose_node, "perm", perm);
-    auto softmax_node =
-        helper_->MakeNode("Softmax", {transpose_node->output(0)});
-    int64_t axis_last = -1;
-    AddAttribute(softmax_node, "axis", axis_last);
-    auto transpose_node_last = helper_->MakeNode(
-        "Transpose", {softmax_node->output(0)}, {output_info[0].name});
-    AddAttribute(transpose_node_last, "perm", perm);
+    if (axis_ < 0) {
+      axis_ = axis_ + output_info[0].Rank();
+    }
+    if (axis_ == output_info[0].Rank() - 1) {
+      auto node = helper_->MakeNode("Softmax", {input_info[0].name},
+                                    {output_info[0].name});
+      AddAttribute(node, "axis", axis_);
+    } else {
+      std::vector<int64_t> perm = Arange(0, output_info[0].Rank());
+      perm[output_info[0].Rank() - 1] = axis_;
+      perm[axis_] = output_info[0].Rank() - 1;
+      auto transpose_node =
+          helper_->MakeNode("Transpose", {input_info[0].name});
+      AddAttribute(transpose_node, "perm", perm);
+      auto softmax_node =
+          helper_->MakeNode("Softmax", {transpose_node->output(0)});
+      int64_t axis_last = -1;
+      AddAttribute(softmax_node, "axis", axis_last);
+      auto transpose_node_last = helper_->MakeNode(
+          "Transpose", {softmax_node->output(0)}, {output_info[0].name});
+      AddAttribute(transpose_node_last, "perm", perm);
+    }
   }
 }
@@ -296,9 +304,16 @@ void SoftMaxMapper::Opset13() {
   GetAttr("axis", &axis);
   auto input_info = GetInput("X");
   auto output_info = GetOutput("Out");
-  auto node =
-      helper_->MakeNode("Softmax", {input_info[0].name}, {output_info[0].name});
-  AddAttribute(node, "axis", axis);
+  if (input_info[0].Rank() == 0) {
+    auto unsqueeze = helper_->Unsqueeze(input_info[0].name, {0});
+    auto node = helper_->MakeNode("Softmax", {unsqueeze});
+    AddAttribute(node, "axis", static_cast<int64_t>(0));
+    helper_->Squeeze(node->output(0), output_info[0].name, {0});
+  } else {
+    auto node = helper_->MakeNode("Softmax", {input_info[0].name},
+                                  {output_info[0].name});
+    AddAttribute(node, "axis", axis);
+  }
 }
 
 void BReluMapper::Opset7() {
@@ -357,7 +372,6 @@ void SizeMapper::Opset7() {
   auto out_info = GetOutput("Out");
   auto output =
       helper_->MakeNode("Size", {GetInput("Input")[0].name})->output(0);
-  output = helper_->Reshape(output, {-1});
   output = helper_->AutoCast(output, out_info[0].name, P2ODataType::INT64,
                              out_info[0].dtype);
 }
@@ -382,21 +396,28 @@ void LogSigmoidMapper::Opset7() {
 void LogSoftmaxMapper::Opset7() {
   auto input_info = GetInput("X");
   auto axis = axis_;
-  if (axis < 0) {
-    axis += input_info[0].Rank();
-  }
-  if (axis == input_info[0].Rank() - 1) {
-    auto node = helper_->MakeNode("LogSoftmax", {input_info[0].name},
-                                  {GetOutput("Out")[0].name});
-    AddAttribute(node, "axis", axis);
+  if (input_info[0].Rank() == 0) {
+    auto unsqueeze = helper_->Unsqueeze(input_info[0].name, {0});
+    auto node = helper_->MakeNode("LogSoftmax", {unsqueeze});
+    AddAttribute(node, "axis", static_cast<int64_t>(0));
+    helper_->Squeeze(node->output(0), GetOutput("Out")[0].name, {0});
   } else {
-    auto perm = Arange(0, input_info[0].Rank());
-    perm[input_info[0].Rank() - 1] = axis;
-    perm[axis] = input_info[0].Rank() - 1;
-    auto output = helper_->Transpose(input_info[0].name, perm);
-    auto node = helper_->MakeNode("LogSoftmax", {output});
-    AddAttribute(node, "axis", int64_t(-1));
-    helper_->Transpose(node->output(0), GetOutput("Out")[0].name, perm);
+    if (axis < 0) {
+      axis += input_info[0].Rank();
+    }
+    if (axis == input_info[0].Rank() - 1) {
+      auto node = helper_->MakeNode("LogSoftmax", {input_info[0].name},
+                                    {GetOutput("Out")[0].name});
+      AddAttribute(node, "axis", axis);
+    } else {
+      auto perm = Arange(0, input_info[0].Rank());
+      perm[input_info[0].Rank() - 1] = axis;
+      perm[axis] = input_info[0].Rank() - 1;
+      auto output = helper_->Transpose(input_info[0].name, perm);
+      auto node = helper_->MakeNode("LogSoftmax", {output});
+      AddAttribute(node, "axis", int64_t(-1));
+      helper_->Transpose(node->output(0), GetOutput("Out")[0].name, perm);
+    }
   }
 }
@@ -420,7 +441,7 @@ void ThresholdedReluMapper::Opset10() {
 void Log1PMapper::Opset7() {
   auto x_info = GetInput("X");
   auto out_info = GetOutput("Out");
-  auto one = helper_->Constant({1}, GetOnnxDtype(x_info[0].dtype), float(1.0));
+  auto one = helper_->Constant({}, GetOnnxDtype(x_info[0].dtype), float(1.0));
   auto input = helper_->MakeNode("Add", {x_info[0].name, one})->output(0);
   helper_->MakeNode("Log", {input}, {out_info[0].name});
 }
@@ -429,7 +450,7 @@ void Log2Mapper::Opset7() {
   auto x_info = GetInput("X");
   auto out_info = GetOutput("Out");
   double ln2 = 0.693147180559945309;
-  auto ln2_tensor = helper_->Constant({1}, GetOnnxDtype(x_info[0].dtype), ln2);
+  auto ln2_tensor = helper_->Constant({}, GetOnnxDtype(x_info[0].dtype), ln2);
   auto output = helper_->MakeNode("Log", {x_info[0].name})->output(0);
   helper_->MakeNode("Div", {output, ln2_tensor}, {out_info[0].name});
 }
@@ -438,8 +459,7 @@ void Log10Mapper::Opset7() {
   auto x_info = GetInput("X");
   auto out_info = GetOutput("Out");
   double ln10 = 2.30258509299404568401;
-  auto ln10_tensor =
-      helper_->Constant({1}, GetOnnxDtype(x_info[0].dtype), ln10);
+  auto ln10_tensor = helper_->Constant({}, GetOnnxDtype(x_info[0].dtype), ln10);
   auto output = helper_->MakeNode("Log", {x_info[0].name})->output(0);
   helper_->MakeNode("Div", {output, ln10_tensor}, {out_info[0].name});
 }

paddle2onnx/mapper/elementwise.cc (3 changes; Normal file → Executable file)

@@ -108,9 +108,8 @@ void ElementWiseModMapper::Opset10() {
   auto abs_y_node = helper_->MakeNode("Abs", {input_y_info[0].name});
 
   auto dtype = input_y_info[0].dtype;
-  std::vector<float> val_0 = {0.0};
-  std::string zero_node = helper_->Constant(GetOnnxDtype(dtype), val_0);
+  std::string zero_node = helper_->Constant({}, GetOnnxDtype(dtype), 0.0);
 
   auto mod_node =
       helper_->MakeNode("Mod", {abs_x_node->output(0), abs_y_node->output(0)});

@@ -74,6 +74,80 @@ void ModelExporter::ExportInputOutputs(
   }
 }
 
+void ModelExporter::CovertCustomOps(const PaddleParser& parser,
+                                    OnnxHelper* helper, int64_t block_id,
+                                    int64_t op_id) {
+  auto op = parser.GetOpDesc(block_id, op_id);
+  std::vector<std::string> input_strs;
+  for (auto i_index = 0; i_index < op.inputs_size(); i_index++) {
+    auto input = op.inputs(i_index);
+    std::string parameter = input.parameter();
+    if (parser.OpHasInput(block_id, op_id, parameter)) {
+      auto input_info = parser.GetOpInput(block_id, op_id, parameter);
+      for (auto input : input_info) {
+        input_strs.push_back(input.name);
+        helper->MakeValueInfo(input.name, input.dtype, input.shape);
+      }
+    }
+  }
+  std::vector<std::string> output_strs;
+  for (auto o_index = 0; o_index < op.outputs_size(); o_index++) {
+    auto output = op.outputs(o_index);
+    std::string parameter = output.parameter();
+    if (parser.OpHasOutput(block_id, op_id, parameter)) {
+      auto output_info = parser.GetOpOutput(block_id, op_id, parameter);
+      for (auto output : output_info) {
+        output_strs.push_back(output.name);
+        helper->MakeValueInfo(output.name, output.dtype, output.shape);
+      }
+    }
+  }
+  auto node = helper->MakeNode(custom_ops[op.type()], input_strs, output_strs);
+  node->set_domain("Paddle");
+  for (auto attr_index = 0; attr_index < op.attrs_size(); attr_index++) {
+    auto attr = op.attrs(attr_index);
+    std::string attr_name = attr.name();
+    if (attr_name == "op_callstack") {
+      continue;
+    }
+    if (attr.has_i() || attr.has_l()) {
+      int64_t val;
+      parser.GetOpAttr(op, attr_name, &val);
+      AddAttribute(node, attr_name, val);
+    } else if (attr.has_f()) {
+      float val;
+      parser.GetOpAttr(op, attr_name, &val);
+      AddAttribute(node, attr_name, val);
+    } else if (attr.has_b()) {
+      bool val;
+      parser.GetOpAttr(op, attr_name, &val);
+      AddAttribute(node, attr_name, static_cast<int64_t>(val));
+    } else if (attr.has_s()) {
+      std::string val;
+      parser.GetOpAttr(op, attr_name, &val);
+      AddAttribute(node, attr_name, val);
+    } else if (attr.ints_size() > 0 || attr.longs_size() > 0) {
+      std::vector<int64_t> vec;
+      parser.GetOpAttr(op, attr_name, &vec);
+      AddAttribute(node, attr_name, vec);
+    } else if (attr.floats_size() > 0) {
+      std::vector<float> vec;
+      parser.GetOpAttr(op, attr_name, &vec);
+      AddAttribute(node, attr_name, vec);
+    } else if (attr.float64s_size() > 0) {
+      std::vector<double> vec;
+      parser.GetOpAttr(op, attr_name, &vec);
+      std::vector<float> fp32_vec;
+      for (auto val : vec) {
+        fp32_vec.push_back(static_cast<float>(val));
+      }
+      AddAttribute(node, attr_name, fp32_vec);
+    }
+  }
+  P2OLogger(true) << op.type() << " is exported as custom operator: "
+                  << custom_ops[op.type()] << std::endl;
+}
+
 void ModelExporter::ExportOp(const PaddleParser& parser, OnnxHelper* helper,
                              int32_t opset_version, int64_t block_id,
                              int64_t op_id, bool verbose) {
@@ -87,20 +161,24 @@ void ModelExporter::ExportOp(const PaddleParser& parser, OnnxHelper* helper,
     return ExportLoop(parser, helper, opset_version, block_id, op_id, verbose);
   }
 
-  auto mapper = MapperHelper::Get()->CreateMapper(op.type(), parser, helper,
-                                                  block_id, op_id);
-  mapper->deploy_backend = _deploy_backend;
+  if (MapperHelper::Get()->IsRegistered(op.type())) {
+    auto mapper = MapperHelper::Get()->CreateMapper(op.type(), parser, helper,
+                                                    block_id, op_id);
+    mapper->deploy_backend = _deploy_backend;
 #ifdef PADDLE2ONNX_DEBUG
-  P2OLogger(true) << "Mapper Name: " << mapper->Name() << std::endl;
+    P2OLogger(true) << "Mapper Name: " << mapper->Name() << std::endl;
 #endif
-  // Some operators will export as custom operator
-  auto iter = custom_ops.find(op.type());
-  if (iter != custom_ops.end()) {
-    mapper->export_as_custom_op = true;
-    mapper->custom_op_name = iter->second;
+    // Some operators will export as custom operator
+    auto iter = custom_ops.find(op.type());
+    if (iter != custom_ops.end()) {
+      mapper->export_as_custom_op = true;
+      mapper->custom_op_name = iter->second;
+    }
+    mapper->Run();
+    delete mapper;
+  } else if (custom_ops.find(op.type()) != custom_ops.end()) {
+    CovertCustomOps(parser, helper, block_id, op_id);
   }
-  mapper->Run();
-  delete mapper;
 
 #ifdef PADDLE2ONNX_DEBUG
   P2OLogger(true) << "---Converting operator: " << op.type() << " done---"
@@ -252,7 +330,8 @@ std::string ModelExporter::Run(
     bool verbose, bool enable_onnx_checker, bool enable_experimental_op,
     bool enable_optimize, const std::string& deploy_backend,
     std::string* calibration_cache, const std::string& external_file,
-    bool* save_external, bool export_fp16_model) {
+    bool* save_external, bool export_fp16_model,
+    std::vector<std::string> disable_fp16_op_types) {
   _deploy_backend = deploy_backend;
   _helper.SetOpsetVersion(opset_version);
   _total_ops_num = 0;
@@ -383,6 +462,7 @@ std::string ModelExporter::Run(
     P2OLogger(verbose) << "Convert FP32 ONNX model to FP16." << std::endl;
     ConvertFp32ToFp16 convert;
     convert.SetCustomOps(custom_ops);
+    convert.AddDisabledOpTypes(disable_fp16_op_types);
     convert.Convert(&onnx_model);
   }
@@ -429,6 +509,9 @@ bool ModelExporter::CheckIfOpSupported(const PaddleParser& parser,
       }
       continue;
     }
+    if (custom_ops.find(op.type()) != custom_ops.end()) {
+      continue;
+    }
     if (!MapperHelper::Get()->IsRegistered(op.type())) {
       unsupported_ops->insert(op.type());
     } else if (!enable_experimental_op) {
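CovertCustomOps emits an unregistered op directly as a node in the custom "Paddle" domain and copies its Paddle attributes onto the node. A rough equivalent with the onnx Python helper (op and attribute names here are hypothetical):

    import onnx

    node = onnx.helper.make_node(
        "MyOp", inputs=["x"], outputs=["y"],
        domain="Paddle",             # mirrors node->set_domain("Paddle")
        alpha=0.1, repeats=[1, 2])   # attributes copied from the op desc
    print(node)

A runtime consuming such a model must register a kernel for the "Paddle" domain, since these nodes lie outside the standard ONNX op set.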

paddle2onnx/mapper/exporter.h (7 changes; Executable file → Normal file)

@@ -66,7 +66,8 @@ struct ModelExporter {
   void ExportLoop(const PaddleParser& parser, OnnxHelper* helper,
                   int32_t opset_version, int64_t block_id, int64_t op_id,
                   bool verbose);
+  void CovertCustomOps(const PaddleParser& parser, OnnxHelper* helper,
+                       int64_t block_id, int64_t op_id);
 
   ONNX_NAMESPACE::ModelProto Optimize(const ONNX_NAMESPACE::ModelProto& model);
 
  public:
@@ -115,8 +116,8 @@ struct ModelExporter {
       const std::string& deploy_backend = "onnxruntime",
       std::string* calibration_cache = nullptr,
       const std::string& external_file = "",
-      bool* save_external = nullptr,
-      bool export_fp16_model = false);
+      bool* save_external = nullptr, bool export_fp16_model = false,
+      std::vector<std::string> disable_fp16_op_types = {});
 };
 
 }  // namespace paddle2onnx

@@ -55,9 +55,8 @@ void DropoutMapper::Opset7() {
   } else {
     GetAttr("dropout_prob", &dropout_prob_);
   }
-  std::vector<float> value = {1 - dropout_prob_};
-  std::string scale_node =
-      helper_->Constant(GetOnnxDtype(input_info[0].dtype), value);
+  std::string scale_node = helper_->Constant(
+      {}, GetOnnxDtype(input_info[0].dtype), 1 - dropout_prob_);
   helper_->MakeNode("Mul", {input_info[0].name, scale_node},
                     {output_info[0].name});
 }
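At inference time dropout reduces to scaling by the keep probability, which is why the mapper emits a single scalar constant and a Mul. A numpy sketch (values are placeholders):

    import numpy as np

    dropout_prob = 0.3
    x = np.ones((2, 2), dtype=np.float32)
    y = x * np.float32(1.0 - dropout_prob)  # what the exported Mul computes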

@@ -0,0 +1,54 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle2onnx/mapper/nn/pad.h"
+
+namespace paddle2onnx {
+REGISTER_MAPPER(pad, PadMapper)
+
+std::vector<int64_t> PadMapper::ConvertPaddingParameter(
+    const std::vector<int64_t>& paddings) {
+  std::vector<int64_t> new_paddings(paddings.size(), 0);
+  Assert(paddings.size() % 2 == 0, "The size of padding should be even");
+  int64_t half_paddings_len = paddings.size() / 2;
+  for (auto i = 0; i < half_paddings_len; ++i) {
+    new_paddings[i] = paddings[2 * i];
+    new_paddings[i + half_paddings_len] = paddings[2 * i + 1];
+  }
+  return new_paddings;
+}
+
+void PadMapper::Opset7() {
+  auto input_info = GetInput("X");
+  auto output_info = GetOutput("Out");
+  auto node =
+      helper_->MakeNode("Pad", {input_info[0].name}, {output_info[0].name});
+  AddAttribute(node, "mode", "constant");
+  AddAttribute(node, "value", pad_value_);
+  AddAttribute(node, "pads", ConvertPaddingParameter(paddings_));
+}
+
+void PadMapper::Opset11() {
+  auto input_info = GetInput("X");
+  auto output_info = GetOutput("Out");
+  auto paddings = helper_->Constant(ONNX_NAMESPACE::TensorProto::INT64,
+                                    ConvertPaddingParameter(paddings_));
+  auto value =
+      helper_->Constant({}, GetOnnxDtype(input_info[0].dtype), pad_value_);
+  auto node = helper_->MakeNode("Pad", {input_info[0].name, paddings, value},
+                                {output_info[0].name});
+  AddAttribute(node, "mode", "constant");
+}
+}  // namespace paddle2onnx
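Paddle stores paddings interleaved per axis as [begin, end, begin, end, ...], while ONNX Pad expects all begins followed by all ends. A sketch of the reordering that ConvertPaddingParameter performs:

    def convert_padding(paddings):
        # Paddle: [d0_begin, d0_end, d1_begin, d1_end, ...]
        # ONNX:   [d0_begin, d1_begin, ..., d0_end, d1_end, ...]
        assert len(paddings) % 2 == 0
        half = len(paddings) // 2
        out = [0] * len(paddings)
        for i in range(half):
            out[i] = paddings[2 * i]
            out[i + half] = paddings[2 * i + 1]
        return out

    print(convert_padding([1, 2, 3, 4]))  # [1, 3, 2, 4]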

@@ -0,0 +1,41 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+
+#include "paddle2onnx/mapper/mapper.h"
+
+namespace paddle2onnx {
+
+class PadMapper : public Mapper {
+ public:
+  PadMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id,
+            int64_t op_id)
+      : Mapper(p, helper, block_id, op_id) {
+    GetAttr("pad_value", &pad_value_);
+    GetAttr("paddings", &paddings_);
+  }
+
+  void Opset7();
+  void Opset11();
+
+ private:
+  std::vector<int64_t> ConvertPaddingParameter(
+      const std::vector<int64_t>& paddings);
+  std::vector<int64_t> paddings_;
+  float pad_value_;
+};
+}  // namespace paddle2onnx

paddle2onnx/mapper/nn/pool2d.cc (8 changes; Normal file → Executable file)

@@ -156,11 +156,11 @@ void Pool2dMapper::NoAdaptivePool(const std::vector<TensorInfo>& input_info,
       pads_[i] = copy[index[i]];
     }
   }
-  if (input_shape[2] > 0 && input_shape[2] + pads_[0] < k_size_[0]) {
-    k_size_[0] = input_shape[2] + pads_[0];
+  if (input_shape[2] > 0 && input_shape[2] + pads_[0] + pads_[2] < k_size_[0]) {
+    k_size_[0] = input_shape[2] + pads_[0] + pads_[2];
   }
-  if (input_shape[3] > 0 && input_shape[3] + pads_[1] < k_size_[1]) {
-    k_size_[1] = input_shape[3] + pads_[1];
+  if (input_shape[3] > 0 && input_shape[3] + pads_[1] + pads_[3] < k_size_[1]) {
+    k_size_[1] = input_shape[3] + pads_[1] + pads_[3];
   }
   int64_t max_ksize = *std::max_element(std::begin(k_size_), std::end(k_size_));
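The fix clamps the kernel size against the padded extent on both sides; counting only the leading pad could shrink a window that actually fits. A worked example with placeholder sizes:

    h, pad_top, pad_bottom, k = 5, 1, 1, 7
    k_old = min(k, h + pad_top)               # 6: clipped too aggressively
    k_new = min(k, h + pad_top + pad_bottom)  # 7: the window fits as-is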

paddle2onnx/mapper/onnx_helper.cc (2 changes; Normal file → Executable file)

@@ -74,7 +74,7 @@ void AddAttribute(std::shared_ptr<ONNX_NAMESPACE::NodeProto> node,
   for (auto& item : values) {
     attr->add_floats(item);
   }
-  attr->set_type(ONNX_NAMESPACE::AttributeProto::FLOAT);
+  attr->set_type(ONNX_NAMESPACE::AttributeProto::FLOATS);
 }
 
 void AddAttribute(std::shared_ptr<ONNX_NAMESPACE::NodeProto> node,

paddle2onnx/mapper/quantize_helper.cc (36 changes; Executable file → Normal file)

@@ -174,6 +174,7 @@ void QuantizeModelProcessor::ProcessQuantizeModel(
     // 6. use topo sort in nodes
     QuantizeInfoBroadcast();
     RemoveAllQuantizeOps();
+    RemoveIdentityOp();
     MergeConvAdd();
     MergeConvBN();
     AddQDQForRKNN();
@@ -187,6 +188,19 @@ void QuantizeModelProcessor::ProcessQuantizeModel(
   }
 }
 
+void QuantizeModelProcessor::RemoveIdentityOp() {
+  UpdateInputNameToNodes();
+  auto iter = nodes_->begin();
+  while (iter != nodes_->end()) {
+    auto node = *iter;
+    if (node->op_type() == "Identity" && !ConnectToOutput(node->output(0))) {
+      RemoveNodeByName(node->name());
+    } else {
+      iter++;
+    }
+  }
+}
+
 void QuantizeModelProcessor::AddQDQForRKNN() {
   UpdateInputNameToNodes();
   supported_quantize_type_ = {"Abs",
@@ -226,6 +240,7 @@ void QuantizeModelProcessor::AddQDQForRKNN() {
                               "Split",
                               "Sqrt",
                               "Tan",
+                              "MatMul",
                               "Tanh"};
   for (auto iter = nodes_->begin(); iter < nodes_->end(); iter++) {
     auto node = *iter;
@@ -582,12 +597,11 @@ void QuantizeModelProcessor::AddQDQInModel(
       std::vector<float> bias;
       Assert(GetTensorByName(name, &bias),
              "[QuantizeModelProcessor] Can not find bias value: " + name);
-      std::vector<int32_t> new_bias(scale.size(), 0);
+      std::vector<int32_t> new_bias(bias.size(), 0);
       for (int64_t i = 0; i < bias.size(); i++) {
         float scale_val = scale.size() == 1 ? scale[0] : scale[i];
        new_bias[i] = rint(bias[i] / scale_val);
       }
-
       Weight updated_bias;
       std::vector<int64_t> bias_shape = {static_cast<int64_t>(new_bias.size())};
       updated_bias.set(P2ODataType::INT32, bias_shape, new_bias);
@@ -980,10 +994,22 @@ void QuantizeModelProcessor::RemoveAllQuantizeOps() {
       continue;
     }
     std::string input_name = node->input(0);
-    RemoveNodeByName(node->name());
+    RemoveNodeByName(node->name(), false);
     std::string output_name = next_node_names[0]->output(0);
-    RemoveNodeByName(next_node_names[0]->name());
-    ReplaceInputOfAllNodes(output_name, input_name);
+    RemoveNodeByName(next_node_names[0]->name(), false);
+    if (ConnectToOutput(output_name)) {
+      for (auto pre_iter = nodes_->begin(); pre_iter < nodes_->end();
+           pre_iter++) {
+        auto pre_node = *pre_iter;
+        for (size_t o_idex = 0; o_idex < pre_node->output_size(); ++o_idex) {
+          if (pre_node->output(o_idex) == input_name) {
+            pre_node->set_output(o_idex, output_name);
+          }
+        }
+      }
+    } else {
+      ReplaceInputOfAllNodes(output_name, input_name);
+    }
   }
 }
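In AddQDQInModel, the requantized bias vector is now sized by the bias length rather than the scale length; with a per-tensor (single-element) scale, the old sizing truncated the bias to one element. A numpy sketch with placeholder values:

    import numpy as np

    bias = np.array([0.02, -0.13, 0.40], dtype=np.float32)
    scale = np.array([0.001], dtype=np.float32)         # per-tensor scale
    scales = np.broadcast_to(scale, bias.shape) if scale.size == 1 else scale
    new_bias = np.rint(bias / scales).astype(np.int32)  # sized by bias.size()
    print(new_bias)  # [  20 -130  400]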

paddle2onnx/mapper/quantize_helper.h (2 changes; Normal file → Executable file)

@@ -84,6 +84,8 @@ struct QuantizeModelProcessor {
   // Add QDQ for RKNN
   void AddQDQForRKNN();
 
+  void RemoveIdentityOp();
+
   // Add quantize related op in model according to tensor names
   void AddQDQInModel(const std::vector<std::string>& tensors_to_be_quantize);

@@ -36,13 +36,6 @@ void ArgMaxMapper::Opset7() {
     input = helper_->Flatten(input_info[0].name);
   }
 
-  // Make sure the output tensor has to be 1D-Tensor
-  bool need_unsqueeze = false;
-  if (flatten_ || input_info[0].shape.size() <= 1) {
-    if (!keepdims_) {
-      need_unsqueeze = true;
-    }
-  }
   if (IsAttrVar("axis")) {
     auto axis_info = GetAttrVar("axis");
     std::vector<int64_t> temp;
@@ -60,13 +53,17 @@ void ArgMaxMapper::Opset7() {
   auto arg_node = helper_->MakeNode("ArgMax", {input});
   AddAttribute(arg_node, "axis", axis_);
   AddAttribute(arg_node, "keepdims", static_cast<int64_t>(keepdims_));
-  if (!need_unsqueeze) {
-    helper_->AutoCast(arg_node->output(0), output_info[0].name,
-                      P2ODataType::INT64, output_info[0].dtype);
-  } else {
-    auto out = helper_->Unsqueeze(arg_node->output(0), {0});
+  if (keepdims_) {
+    std::vector<int64_t> shape(input_info[0].Rank(), 1);
+    std::string out = arg_node->output(0);
+    if (flatten_) {
+      out = helper_->Reshape(arg_node->output(0), shape);
+    }
     helper_->AutoCast(out, output_info[0].name, P2ODataType::INT64,
                       output_info[0].dtype);
+  } else {
+    helper_->AutoCast(arg_node->output(0), output_info[0].name,
+                      P2ODataType::INT64, output_info[0].dtype);
   }
 }
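With keepdims set and a flattened input, the scalar ArgMax result is reshaped back to an all-ones shape of the original rank instead of being unsqueezed to 1-D. A numpy sketch:

    import numpy as np

    x = np.arange(6, dtype=np.float32).reshape(2, 3)
    flat_idx = np.argmax(x)             # flatten path: scalar index 5
    out = np.reshape(flat_idx, (1, 1))  # keepdims: rank restored as (1, 1)
    print(out.shape, out)               # (1, 1) [[5]]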

@@ -36,13 +36,6 @@ void ArgMinMapper::Opset7() {
     input = helper_->Flatten(input_info[0].name);
   }
 
-  // Make sure the output tensor has to be 1D-Tensor
-  bool need_unsqueeze = false;
-  if (flatten_ || input_info[0].shape.size() <= 1) {
-    if (!keepdims_) {
-      need_unsqueeze = true;
-    }
-  }
   if (IsAttrVar("axis")) {
     auto axis_info = GetAttrVar("axis");
     std::vector<int64_t> temp;
@@ -51,6 +44,7 @@ void ArgMinMapper::Opset7() {
   } else {
     GetAttr("axis", &axis_);
   }
+
   if (input_info[0].dtype == P2ODataType::FP64) {
     input = helper_->AutoCast(input, P2ODataType::FP64, P2ODataType::FP32);
   }
@@ -60,13 +54,17 @@ void ArgMinMapper::Opset7() {
   auto arg_node = helper_->MakeNode("ArgMin", {input});
   AddAttribute(arg_node, "axis", axis_);
   AddAttribute(arg_node, "keepdims", static_cast<int64_t>(keepdims_));
-  if (!need_unsqueeze) {
-    helper_->AutoCast(arg_node->output(0), output_info[0].name,
-                      P2ODataType::INT64, output_info[0].dtype);
-  } else {
-    auto out = helper_->Unsqueeze(arg_node->output(0), {0});
+  if (keepdims_) {
+    std::vector<int64_t> shape(input_info[0].Rank(), 1);
+    std::string out = arg_node->output(0);
+    if (flatten_) {
+      out = helper_->Reshape(arg_node->output(0), shape);
+    }
     helper_->AutoCast(out, output_info[0].name, P2ODataType::INT64,
                       output_info[0].dtype);
+  } else {
+    helper_->AutoCast(arg_node->output(0), output_info[0].name,
+                      P2ODataType::INT64, output_info[0].dtype);
   }
 }

@@ -0,0 +1,76 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle2onnx/mapper/tensor/atan2.h"
+
+#define M_PI 3.14159265358979323846 /* pi */
+
+namespace paddle2onnx {
+REGISTER_MAPPER(atan2, Atan2Mapper)
+
+int32_t Atan2Mapper::GetMinOpset(bool verbose) {
+  if (GetInput("X1")[0].dtype == P2ODataType::INT32 ||
+      GetInput("X2")[0].dtype == P2ODataType::INT32 ||
+      GetInput("X1")[0].dtype == P2ODataType::INT64 ||
+      GetInput("X2")[0].dtype == P2ODataType::INT64) {
+    Error() << "The input dtype should be float32 or float64. " << std::endl;
+    return -1;
+  }
+  Logger(verbose, 9) << RequireOpset(9) << std::endl;
+  return 9;
+}
+
+void Atan2Mapper::Opset9() {
+  auto x_info = GetInput("X1");
+  auto y_info = GetInput("X2");
+  auto out_info = GetOutput("Out");
+
+  std::string input_x_name = x_info[0].name;
+  std::string input_y_name = y_info[0].name;
+
+  auto dtype = P2ODataType::FP32;
+
+  if (x_info[0].dtype == P2ODataType::FP64 ||
+      y_info[0].dtype == P2ODataType::FP64) {
+    input_x_name =
+        helper_->AutoCast(x_info[0].name, x_info[0].dtype, P2ODataType::FP32);
+    input_y_name =
+        helper_->AutoCast(y_info[0].name, y_info[0].dtype, P2ODataType::FP32);
+  }
+
+  auto div = helper_->MakeNode("Div", {input_x_name, input_y_name});
+  auto atan = helper_->MakeNode("Atan", {div->output(0)});
+
+  std::string zero_node =
+      helper_->Constant(GetOnnxDtype(dtype), std::vector<float>{0.0});
+  auto minus_node = helper_->MakeNode("Less", {input_y_name, zero_node});
+
+  std::string condition_node =
+      helper_->AutoCast(minus_node->output(0), dtype, P2ODataType::BOOL);
+
+  std::string pi_node =
+      helper_->Constant(GetOnnxDtype(dtype), std::vector<float>{M_PI});
+
+  auto sign_node = helper_->MakeNode("Sign", {input_x_name});
+  auto mul_node = helper_->MakeNode("Mul", {sign_node->output(0), pi_node});
+
+  auto where_node = helper_->MakeNode(
+      "Where", {condition_node, mul_node->output(0), zero_node});
+
+  auto add_node =
+      helper_->MakeNode("Add", {atan->output(0), where_node->output(0)});
+
+  helper_->AutoCast(add_node->output(0), out_info[0].name, dtype,
+                    out_info[0].dtype);
+}
+}  // namespace paddle2onnx
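The graph implements atan2 via the identity atan2(x, y) = atan(x / y) + sign(x) * pi * [y < 0]. A quick numerical check (sketch; x = 0 with y < 0 is not covered by this identity, and y = 0 would divide by zero):

    import math

    def atan2_via_atan(x, y):
        sign = (x > 0) - (x < 0)
        correction = sign * math.pi if y < 0 else 0.0
        return math.atan(x / y) + correction

    for x, y in [(1, 1), (1, -1), (-1, -1), (-1, 1)]:
        assert abs(atan2_via_atan(x, y) - math.atan2(x, y)) < 1e-12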

@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle2onnx/mapper/mapper.h"
+
+namespace paddle2onnx {
+
+class Atan2Mapper : public Mapper {
+ public:
+  Atan2Mapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id,
+              int64_t op_id)
+      : Mapper(p, helper, block_id, op_id) {}
+
+  void Opset9();
+  int32_t GetMinOpset(bool verbose = false);
+};
+}  // namespace paddle2onnx

@@ -25,17 +25,33 @@ int32_t CumsumMapper::GetMinOpset(bool verbose) {
 void CumsumMapper::Opset11() {
   auto input_info = GetInput("X");
   auto output_info = GetOutput("Out");
-  std::string axis_node;
-  if (IsAttrVar("axis")) {
-    auto axis_info = GetAttrVar("axis");
-    axis_node = helper_->AutoCast(axis_info[0].name, axis_info[0].dtype,
-                                  P2ODataType::INT64);
+  if (input_info[0].Rank() == 0) {
+    auto axis_node = helper_->Constant({}, GetOnnxDtype(P2ODataType::INT64), 0);
+    auto unsqueeze_node = helper_->Unsqueeze(input_info[0].name, {0});
+    auto cumsum_node = helper_->MakeNode("CumSum", {unsqueeze_node, axis_node});
+    if (flatten_) {
+      helper_->AutoCast(cumsum_node->output(0), output_info[0].name,
+                        input_info[0].dtype, output_info[0].dtype);
+    } else {
+      helper_->Squeeze(cumsum_node->output(0), output_info[0].name, {0});
+    }
   } else {
-    GetAttr("axis", &axis_);
-    axis_node = helper_->Constant({1}, GetOnnxDtype(P2ODataType::INT64), axis_);
+    std::string axis_node;
+    if (IsAttrVar("axis")) {
+      auto axis_info = GetAttrVar("axis");
+      axis_node = helper_->AutoCast(axis_info[0].name, axis_info[0].dtype,
+                                    P2ODataType::INT64);
+    } else {
+      GetAttr("axis", &axis_);
+      axis_node =
+          helper_->Constant({}, GetOnnxDtype(P2ODataType::INT64), axis_);
+    }
+    std::string input_node = input_info[0].name;
+    if (flatten_) {
+      input_node = helper_->Reshape(input_info[0].name, {-1});
+    }
+    helper_->MakeNode("CumSum", {input_node, axis_node}, {output_info[0].name});
   }
-  helper_->MakeNode("CumSum", {input_info[0].name, axis_node},
-                    {output_info[0].name});
 }
 
 }  // namespace paddle2onnx
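When the flatten attribute is set, the input is reshaped to 1-D before CumSum, matching Paddle's semantics. A numpy sketch:

    import numpy as np

    x = np.array([[1, 2], [3, 4]])
    print(np.cumsum(x.reshape(-1)))  # flatten=True: [ 1  3  6 10]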

@@ -24,12 +24,15 @@ class CumsumMapper : public Mapper {
  public:
   CumsumMapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id,
                int64_t op_id)
-      : Mapper(p, helper, block_id, op_id) {}
+      : Mapper(p, helper, block_id, op_id) {
+    GetAttr("flatten", &flatten_);
+  }
   int32_t GetMinOpset(bool verbose = false);
   void Opset11();
 
  private:
   int64_t axis_;
+  bool flatten_;
 };
 
 }  // namespace paddle2onnx

@@ -33,7 +33,9 @@ int32_t FlipMapper::GetMinOpset(bool verbose) {
 void FlipMapper::Opset7() {
   auto input_info = GetInput("X");
   auto output_info = GetOutput("Out");
-
+  if (input_info[0].Rank() == 0) {
+    helper_->MakeNode("Identity", {input_info[0].name}, {output_info[0].name});
+  }
   std::string input_name = input_info[0].name;
   bool need_convert = false;
   if (input_info[0].dtype == P2ODataType::BOOL ||
@@ -19,11 +19,14 @@ REGISTER_MAPPER(gaussian_random, GaussianRandomMapper)
 int32_t GaussianRandomMapper::GetMinOpset(bool verbose) {
   if (HasInput("ShapeTensor") && !IsConstantInput("ShapeTensor")) {
-    Logger(verbose, 9) << "While ShapeTensor as input and it's not a constant tensor, " << RequireOpset(9) << std::endl;
+    Logger(verbose, 9)
+        << "While ShapeTensor as input and it's not a constant tensor, "
+        << RequireOpset(9) << std::endl;
     return 9;
   }
   if (HasInput("ShapeTensorList")) {
-    Logger(verbose, 9) << "While ShapeTensorList as input, " << RequireOpset(9) << std::endl;
+    Logger(verbose, 9) << "While ShapeTensorList as input, " << RequireOpset(9)
+                       << std::endl;
     return 9;
   }
   return 7;
@@ -36,13 +39,24 @@ void GaussianRandomMapper::Opset7() {
   if (HasInput("ShapeTensor")) {
     if (!TryGetInputValue("ShapeTensor", &shape)) {
       auto shape_info = GetInput("ShapeTensor");
-      shape_tensor_name = helper_->AutoCast(shape_info[0].name, shape_info[0].dtype, P2ODataType::INT64);
+      shape_tensor_name = helper_->AutoCast(
+          shape_info[0].name, shape_info[0].dtype, P2ODataType::INT64);
     }
   } else if (HasInput("ShapeTensorList")) {
     auto shape_info = GetInput("ShapeTensorList");
     shape_tensor_name = helper_->ConcatIndices(shape_info);
   } else {
     shape.assign(shape_.begin(), shape_.end());
+  }
+  if (out_info[0].Rank() == 0) {
+    auto node = helper_->MakeNode("RandomNormal", {});
+    AddAttribute(node, "dtype", GetOnnxDtype(out_info[0].dtype));
+    AddAttribute(node, "mean", mean_);
+    AddAttribute(node, "scale", std_);
+    AddAttribute(node, "shape", std::vector<int64_t>(1, 1));
+    AddAttribute(node, "seed", static_cast<float>(seed_));
+    helper_->Squeeze(node->output(0), {out_info[0].name}, {0});
+    return;
   }
   if (shape.size() > 0) {
     auto node = helper_->MakeNode("RandomNormal", {}, {out_info[0].name});
@@ -52,8 +66,10 @@ void GaussianRandomMapper::Opset7() {
     AddAttribute(node, "shape", shape_);
     AddAttribute(node, "seed", static_cast<float>(seed_));
   } else {
-    auto tensor = helper_->ConstOfShape(shape_tensor_name, GetOnnxDtype(out_info[0].dtype), float(0));
-    auto node = helper_->MakeNode("RandomNormalLike", {tensor}, {out_info[0].name});
+    auto tensor = helper_->ConstOfShape(
+        shape_tensor_name, GetOnnxDtype(out_info[0].dtype), float(0));
+    auto node =
+        helper_->MakeNode("RandomNormalLike", {tensor}, {out_info[0].name});
     AddAttribute(node, "dtype", GetOnnxDtype(out_info[0].dtype));
     AddAttribute(node, "mean", mean_);
     AddAttribute(node, "scale", std_);

@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle2onnx/mapper/tensor/pow.h"
+
 #include <unordered_set>
 
 namespace paddle2onnx {
@@ -22,16 +23,17 @@ void PowMapper::Opset7() {
   auto input_info = GetInput("X");
   auto output_info = GetOutput("Out");
-  auto factor_node = helper_->Constant(ONNX_NAMESPACE::TensorProto::FLOAT,
-                                       std::vector<float>(1, factor_));
+  auto factor_node =
+      helper_->Constant({}, ONNX_NAMESPACE::TensorProto::FLOAT, factor_);
   if (input_info[0].dtype != P2ODataType::FP32) {
     std::string x_cast_name = helper_->AutoCast(
         {input_info[0].name}, input_info[0].dtype, P2ODataType::FP32);
     auto node = helper_->MakeNode("Pow", {x_cast_name, factor_node});
     helper_->AutoCast(node->output(0), {output_info[0].name}, P2ODataType::FP32,
                       input_info[0].dtype);
   } else {
-    helper_->MakeNode("Pow", {input_info[0].name, factor_node}, {output_info[0].name});
+    helper_->MakeNode("Pow", {input_info[0].name, factor_node},
+                      {output_info[0].name});
   }
 }

@@ -40,20 +40,20 @@ void ScaleMapper::Opset7() {
             scale_info[0].name, scale_info[0].dtype, P2ODataType::FP32);
         out = helper_->MakeNode("Mul", {out, scale})->output(0);
       } else {
-        auto scale = helper_->Constant(ONNX_NAMESPACE::TensorProto::FLOAT,
-                                       std::vector<float>(1, scale_));
+        auto scale =
+            helper_->Constant({}, ONNX_NAMESPACE::TensorProto::FLOAT, scale_);
         out = helper_->MakeNode("Mul", {out, scale})->output(0);
       }
     }
     if (!is_bias_0) {
-      auto bias = helper_->Constant(ONNX_NAMESPACE::TensorProto::FLOAT,
-                                    std::vector<float>(1, bias_));
+      auto bias =
+          helper_->Constant({}, ONNX_NAMESPACE::TensorProto::FLOAT, bias_);
      out = helper_->MakeNode("Add", {out, bias})->output(0);
     }
   } else {
     if (!is_bias_0) {
-      auto bias = helper_->Constant(ONNX_NAMESPACE::TensorProto::FLOAT,
-                                    std::vector<float>(1, bias_));
+      auto bias =
+          helper_->Constant({}, ONNX_NAMESPACE::TensorProto::FLOAT, bias_);
       out = helper_->MakeNode("Add", {out, bias})->output(0);
     }
     if (!is_scale_1 || HasInput("ScaleTensor")) {
@@ -63,8 +63,8 @@ void ScaleMapper::Opset7() {
           scale_info[0].name, scale_info[0].dtype, P2ODataType::FP32);
       out = helper_->MakeNode("Mul", {out, scale})->output(0);
     } else {
-      auto scale = helper_->Constant(ONNX_NAMESPACE::TensorProto::FLOAT,
-                                     std::vector<float>(1, scale_));
+      auto scale =
+          helper_->Constant({}, ONNX_NAMESPACE::TensorProto::FLOAT, scale_);
       out = helper_->MakeNode("Mul", {out, scale})->output(0);
     }
   }

@@ -36,9 +36,14 @@ void ScatterMapper::Opset11() {
   std::string ids_node = helper_->AutoCast(
       input_ids_info[0].name, input_ids_info[0].dtype, P2ODataType::INT64);
 
-  std::vector<int64_t> shape_val = {input_ids_info[0].shape[0], 1};
-  std::string shape_node =
-      helper_->Constant(GetOnnxDtype(P2ODataType::INT64), shape_val);
+  std::string shape_node;
+  if (input_ids_info[0].Rank() == 0) {
+    std::vector<int64_t> shape = {1};
+    shape_node = helper_->Constant(GetOnnxDtype(P2ODataType::INT64), shape);
+  } else {
+    std::vector<int64_t> shape = {input_ids_info[0].shape[0], 1};
+    shape_node = helper_->Constant(GetOnnxDtype(P2ODataType::INT64), shape);
+  }
 
   auto reshape_index_node =
       helper_->MakeNode("Reshape", {ids_node, shape_node});
@@ -54,7 +59,7 @@ void ScatterMapper::Opset11() {
     AddAttribute(scatter_nd_node, "reduction", "add");
 
     std::string zero_node = helper_->Constant(
-        {1}, GetOnnxDtype(input_x_info[0].dtype), static_cast<float>(0));
+        {}, GetOnnxDtype(input_x_info[0].dtype), static_cast<float>(0));
 
     auto equal_node =
         helper_->MakeNode("Equal", {scatter_nd_node->output(0), zero_node});
@@ -62,15 +67,16 @@ void ScatterMapper::Opset11() {
     std::string condition_node = helper_->AutoCast(
         equal_node->output(0), P2ODataType::INT64, P2ODataType::BOOL);
 
-    helper_->MakeNode("Where", {condition_node, input_x_info[0].name,
-                                scatter_nd_node->output(0)},
-                      {output_info[0].name});
+    helper_->MakeNode(
+        "Where",
+        {condition_node, input_x_info[0].name, scatter_nd_node->output(0)},
+        {output_info[0].name});
   } else {
-    auto node = helper_->MakeNode(
-        "ScatterND", {input_x_info[0].name, reshape_index_node->output(0),
-                      input_updates_info[0].name},
-        {output_info[0].name});
+    auto node =
+        helper_->MakeNode("ScatterND",
+                          {input_x_info[0].name, reshape_index_node->output(0),
+                           input_updates_info[0].name},
+                          {output_info[0].name});
   }
 }

@@ -49,7 +49,12 @@ void TileMapper::Opset7() {
     }
     repeats = helper_->Constant(ONNX_NAMESPACE::TensorProto::INT64, values);
   }
-  helper_->MakeNode("Tile", {x_info[0].name, repeats}, {out_info[0].name});
+  if (x_info[0].Rank() == 0) {
+    auto unsqueeze = helper_->Unsqueeze(x_info[0].name, {0});
+    helper_->MakeNode("Tile", {unsqueeze, repeats}, {out_info[0].name});
+  } else {
+    helper_->MakeNode("Tile", {x_info[0].name, repeats}, {out_info[0].name});
+  }
 }
 
 }  // namespace paddle2onnx

paddle2onnx/mapper/tensor/top_k.cc (7 changes; Executable file → Normal file)

@@ -21,7 +21,12 @@ void TopKMapper::Opset11() {
   auto x_info = GetInput("X");
   auto output_info = GetOutput("Out");
   auto indices_info = GetOutput("Indices");
-
+  if (x_info[0].Rank() == 0) {
+    helper_->MakeNode("Identity", {x_info[0].name}, {output_info[0].name});
+    helper_->Constant(indices_info[0].name, {},
+                      ONNX_NAMESPACE::TensorProto::INT64, 0);
+    return;
+  }
   std::string k = "";
   if (HasInput("K")) {
     auto k_info = GetInput("K");

paddle2onnx/mapper/tensor/top_k_v2.cc (7 changes; Executable file → Normal file)

@@ -21,7 +21,12 @@ void TopKV2Mapper::Opset11() {
   auto x_info = GetInput("X");
   auto output_info = GetOutput("Out");
   auto indices_info = GetOutput("Indices");
-
+  if (x_info[0].Rank() == 0) {
+    helper_->MakeNode("Identity", {x_info[0].name}, {output_info[0].name});
+    helper_->Constant(indices_info[0].name, {},
+                      ONNX_NAMESPACE::TensorProto::INT64, 0);
+    return;
+  }
   std::string k = "";
   if (HasInput("K")) {
     auto k_info = GetInput("K");
@@ -23,7 +23,11 @@ REGISTER_MAPPER(transpose2, Transpose2Mapper)
 void Transpose2Mapper::Opset7() {
   auto input_info = GetInput("X");
   auto output_info = GetOutput("Out");
-
+  if (input_info[0].Rank() == 0) {
+    helper_->MakeNode("Identity", {input_info[0].name}, {output_info[0].name});
+    return;
+  }
+  GetAttr("axis", &axis_);
   auto node = helper_->MakeNode("Transpose", {input_info[0].name},
                                 {output_info[0].name});
   AddAttribute(node, "perm", axis_);

@@ -24,13 +24,11 @@ class Transpose2Mapper : public Mapper {
  public:
   Transpose2Mapper(const PaddleParser& p, OnnxHelper* helper, int64_t block_id,
                    int64_t op_id)
-      : Mapper(p, helper, block_id, op_id) {
-    GetAttr("axis", &axis_);
-  }
+      : Mapper(p, helper, block_id, op_id) {}
   void Opset7();
 
  private:
-  std::vector<int64_t> axis_;
+  std::vector<int64_t> axis_ = {};
 };
 
 }  // namespace paddle2onnx

@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle2onnx/optimizer/convert_fp32_to_fp16.h"
+
 #include "paddle2onnx/utils/utils.h"
 
 namespace paddle2onnx {
@@ -528,9 +529,8 @@ void ConvertFp32ToFp16::ConvertAttribute(ONNX_NAMESPACE::ModelProto* model) {
           std::find(keep_type_node->input().begin(),
                     keep_type_node->input().end(),
                     n->output()[0]) != keep_type_node->input().end();
-      if (is_pre_node &&
-          std::find(node_list.begin(), node_list.end(), n) ==
-              node_list.end()) {
+      if (is_pre_node && std::find(node_list.begin(), node_list.end(),
+                                   n) == node_list.end()) {
         node_list.push_back(keep_type_node);
         Assert(
             n->op_type() == "Constant",
@@ -604,9 +604,8 @@ void ConvertFp32ToFp16::ConvertAttribute(ONNX_NAMESPACE::ModelProto* model) {
       bool skip =
           std::find(graph_io_to_skip.begin(), graph_io_to_skip.end(),
                     input->name()) != graph_io_to_skip.end();
-      if (!skip &&
-          input->type().tensor_type().elem_type() ==
-              ONNX_NAMESPACE::TensorProto::FLOAT) {
+      if (!skip && input->type().tensor_type().elem_type() ==
+                       ONNX_NAMESPACE::TensorProto::FLOAT) {
         input->mutable_type()->mutable_tensor_type()->set_elem_type(
             ONNX_NAMESPACE::TensorProto::FLOAT16);
         value_info_list.push_back(input);
@@ -617,9 +616,8 @@ void ConvertFp32ToFp16::ConvertAttribute(ONNX_NAMESPACE::ModelProto* model) {
       bool skip =
          std::find(graph_io_to_skip.begin(), graph_io_to_skip.end(),
                     output->name()) != graph_io_to_skip.end();
-      if (!skip &&
-          output->type().tensor_type().elem_type() ==
-              ONNX_NAMESPACE::TensorProto::FLOAT) {
+      if (!skip && output->type().tensor_type().elem_type() ==
+                       ONNX_NAMESPACE::TensorProto::FLOAT) {
         output->mutable_type()->mutable_tensor_type()->set_elem_type(
             ONNX_NAMESPACE::TensorProto::FLOAT16);
         value_info_list.push_back(output);
@@ -819,9 +817,8 @@ bool ConvertFp32ToFp16::IsFP16Model(const ONNX_NAMESPACE::ModelProto& model) {
 }
 
 void ConvertFp32ToFp16::Convert(ONNX_NAMESPACE::ModelProto* model) {
-  if (op_block_list_.empty()) {
-    op_block_list_ = DEFAULT_OP_BLOCK_LIST;
-  }
+  op_block_list_.insert(op_block_list_.end(), DEFAULT_OP_BLOCK_LIST.begin(),
+                        DEFAULT_OP_BLOCK_LIST.end());
   if (custom_ops_.size()) {
     op_block_list_.insert(op_block_list_.end(), custom_ops_.begin(),
                           custom_ops_.end());
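Convert() now always appends DEFAULT_OP_BLOCK_LIST instead of using it only when the block list is empty, so types disabled via AddDisabledOpTypes extend the defaults rather than silently replace them. A behavioral sketch (list contents are placeholders, not the real defaults):

    DEFAULT_OP_BLOCK_LIST = ["Celu", "RandomNormal"]
    user_disabled = ["LayerNormalization"]

    old = user_disabled or DEFAULT_OP_BLOCK_LIST  # old: defaults dropped
    new = user_disabled + DEFAULT_OP_BLOCK_LIST   # new: defaults always kept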

paddle2onnx/optimizer/convert_fp32_to_fp16.h (7 changes; Executable file → Normal file)

@@ -14,12 +14,12 @@
 #pragma once
 
 #include <onnx/onnx_pb.h>
+#include <onnx/shape_inference/implementation.h>
 
 #include <cmath>
 #include <fstream>
 #include <iomanip>
-#include <onnx/shape_inference/implementation.h>
 
 #include "paddle2onnx/mapper/mapper.h"
 #include "paddle2onnx/parser/parser.h"
 
 namespace paddle2onnx {
@@ -137,6 +137,11 @@ struct ConvertFp32ToFp16 {
       }
     }
   }
+
+  void AddDisabledOpTypes(const std::vector<std::string>& disable_fp16_ops) {
+    op_block_list_.insert(op_block_list_.end(), disable_fp16_ops.begin(),
+                          disable_fp16_ops.end());
+  }
   // If the input ONNX model is a FP16 model, return True
   bool IsFP16Model(const ONNX_NAMESPACE::ModelProto& model);