From ee2c6136fca3623e355e6ce631fa42b411decb3d Mon Sep 17 00:00:00 2001
From: Jack Zhou
Date: Sun, 30 Oct 2022 17:04:05 +0800
Subject: [PATCH] [Model] Upgrade uie (#458)

* Upgrade the UIE C++ implementation
* Upgrade the Python UIEModel to inherit from FastDeployModel
* Add schema language support; skip inference when there are no prompts
* Adjust the position of the schema language argument
* Add schema_language for Python and C++
* Update the pybind bindings for UIE
* Fix the UIE arguments
* Add SchemaLanguage
---
 examples/text/uie/cpp/README.md        |  11 ++-
 examples/text/uie/python/README.md     |   5 +-
 examples/text/uie/python/infer.py      |   5 +-
 fastdeploy/text/uie/model.cc           | 105 ++++++++++++++++---------
 fastdeploy/text/uie/model.h            |  26 ++++--
 fastdeploy/text/uie/uie_pybind.cc      |  22 ++++--
 python/fastdeploy/text/__init__.py     |   1 +
 python/fastdeploy/text/uie/__init__.py |  19 +++--
 8 files changed, 128 insertions(+), 66 deletions(-)

diff --git a/examples/text/uie/cpp/README.md b/examples/text/uie/cpp/README.md
index ef370263f..8395fe692 100644
--- a/examples/text/uie/cpp/README.md
+++ b/examples/text/uie/cpp/README.md
@@ -411,20 +411,24 @@ UIEModel(
     const std::vector<std::string>& schema,
     const fastdeploy::RuntimeOption& custom_option =
         fastdeploy::RuntimeOption(),
-    const fastdeploy::ModelFormat& model_format = fastdeploy::ModelFormat::PADDLE);
+    const fastdeploy::ModelFormat& model_format = fastdeploy::ModelFormat::PADDLE,
+    SchemaLanguage schema_language = SchemaLanguage::ZH);
 UIEModel(
     const std::string& model_file, const std::string& params_file,
     const std::string& vocab_file, float position_prob, size_t max_length,
     const SchemaNode& schema,
     const fastdeploy::RuntimeOption& custom_option =
         fastdeploy::RuntimeOption(),
-    const fastdeploy::ModelFormat& model_format = fastdeploy::ModelFormat::PADDLE);
+    const fastdeploy::ModelFormat& model_format = fastdeploy::ModelFormat::PADDLE,
+    SchemaLanguage schema_language = SchemaLanguage::ZH);
 UIEModel(
     const std::string& model_file, const std::string& params_file,
     const std::string& vocab_file, float position_prob, size_t max_length,
     const std::vector<SchemaNode>& schema,
     const fastdeploy::RuntimeOption& custom_option =
         fastdeploy::RuntimeOption(),
-    const fastdeploy::ModelFormat& model_format = fastdeploy::ModelFormat::PADDLE);
+    const fastdeploy::ModelFormat& model_format =
+        fastdeploy::ModelFormat::PADDLE,
+    SchemaLanguage schema_language = SchemaLanguage::ZH);
 ```
 Loads and initializes the UIE model, where model_file and params_file are the Paddle inference files exported from the trained model; see [Model Export](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/model_zoo/uie/README.md#%E6%A8%A1%E5%9E%8B%E9%83%A8%E7%BD%B2) for details.
@@ -439,6 +443,7 @@ Loads and initializes the UIE model, where model_file and params_file are the
 > * **schema**(list(SchemaNode) | SchemaNode | list(str)): The target schema of the extraction task.
 > * **runtime_option**(RuntimeOption): Backend inference configuration. Defaults to None, i.e. the default configuration is used.
 > * **model_format**(ModelFormat): Model format. Defaults to the Paddle format.
+> * **schema_language** (SchemaLanguage): Schema language. Defaults to ZH (Chinese); the currently supported languages are ZH (Chinese) and EN (English).
 
 #### SetSchema function
 
diff --git a/examples/text/uie/python/README.md b/examples/text/uie/python/README.md
index 0133c1525..1c6124e32 100644
--- a/examples/text/uie/python/README.md
+++ b/examples/text/uie/python/README.md
@@ -329,7 +329,9 @@ fd.text.uie.UIEModel(model_file,
                      position_prob=0.5,
                      max_length=128,
                      schema=[],
-                     runtime_option=None,model_format=ModelFormat.PADDLE)
+                     runtime_option=None,
+                     model_format=ModelFormat.PADDLE,
+                     schema_language=SchemaLanguage.ZH)
 ```
 
 Loads and initializes the UIEModel, where `model_file` and `params_file` are the Paddle inference files exported from the trained model (see [Model Export](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/model_zoo/uie/README.md#%E6%A8%A1%E5%9E%8B%E9%83%A8%E7%BD%B2) for details), and `vocab_file` is the vocabulary file; the vocab_file of the corresponding UIE model can be downloaded from the [UIE configuration file](https://github.com/PaddlePaddle/PaddleNLP/blob/5401f01af85f1c73d8017c6b3476242fce1e6d52/model_zoo/uie/utils.py).
@@ -344,6 +346,7 @@ Loads and initializes the UIEModel, where `model_file` and `params_file` are
 > * **schema**(list|dict): The target information of the extraction task.
 > * **runtime_option**(RuntimeOption): Backend inference configuration. Defaults to None, i.e. the default configuration is used.
 > * **model_format**(ModelFormat): Model format. Defaults to the Paddle format.
+> * **schema_language**(SchemaLanguage): Schema language. Defaults to ZH (Chinese); the currently supported languages are ZH (Chinese) and EN (English).
 
 ### set_schema function
 
diff --git a/examples/text/uie/python/infer.py b/examples/text/uie/python/infer.py
index a72ce5542..46df986e1 100644
--- a/examples/text/uie/python/infer.py
+++ b/examples/text/uie/python/infer.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import fastdeploy
-from fastdeploy.text import UIEModel
+from fastdeploy.text import UIEModel, SchemaLanguage
 import os
 from pprint import pprint
 
@@ -80,7 +80,8 @@ if __name__ == "__main__":
         position_prob=0.5,
         max_length=128,
         schema=schema,
-        runtime_option=runtime_option)
+        runtime_option=runtime_option,
+        schema_language=SchemaLanguage.ZH)
 
     print("1. Named Entity Recognition Task")
     print(f"The extraction schema: {schema}")
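The shipped example keeps the default Chinese schema. Below is a minimal sketch of the same Python API with the newly added `schema_language` switched to English. The export directory, file names and input text are hypothetical, it assumes an English UIE model exported the same way as the Chinese one, and it uses the wrapper's `predict()` helper (not shown in this patch):

```python
import fastdeploy as fd
from fastdeploy.text import UIEModel, SchemaLanguage

# Hypothetical export directory of an English UIE model
# (inference.pdmodel, inference.pdiparams, vocab.txt).
model_dir = "uie_base_en_export"

uie_en = UIEModel(
    f"{model_dir}/inference.pdmodel",
    f"{model_dir}/inference.pdiparams",
    f"{model_dir}/vocab.txt",
    position_prob=0.5,
    max_length=128,
    schema=["person", "organization", "time"],
    runtime_option=fd.RuntimeOption(),
    schema_language=SchemaLanguage.EN)  # EN makes relation prompts use " of " instead of "的"

results = uie_en.predict(
    ["In 2017 the conference was hosted by ACM in Los Angeles."])
print(results)
```

The only behavioral difference from the Chinese path is how child prompts are prefixed, which is exactly what the `schema_language_` branch added to `ConstructChildPromptPrefix` in model.cc controls.
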
diff --git a/fastdeploy/text/uie/model.cc b/fastdeploy/text/uie/model.cc
index 6f71dd394..c04d4d6d3 100644
--- a/fastdeploy/text/uie/model.cc
+++ b/fastdeploy/text/uie/model.cc
@@ -165,14 +165,16 @@ UIEModel::UIEModel(const std::string& model_file,
                    const std::string& vocab_file, float position_prob,
                    size_t max_length, const std::vector<std::string>& schema,
                    const fastdeploy::RuntimeOption& custom_option,
-                   const fastdeploy::ModelFormat& model_format)
+                   const fastdeploy::ModelFormat& model_format,
+                   SchemaLanguage schema_language)
     : max_length_(max_length), position_prob_(position_prob),
+      schema_language_(schema_language),
       tokenizer_(vocab_file) {
-  runtime_option_ = custom_option;
-  runtime_option_.model_format = model_format;
-  runtime_option_.SetModelPath(model_file, params_file);
-  runtime_.Init(runtime_option_);
+  runtime_option = custom_option;
+  runtime_option.model_format = model_format;
+  runtime_option.SetModelPath(model_file, params_file);
+  initialized = Initialize();
   SetSchema(schema);
   tokenizer_.EnableTruncMethod(
       max_length, 0, faster_tokenizer::core::Direction::RIGHT,
       faster_tokenizer::core::TruncStrategy::LONGEST_FIRST);
@@ -184,14 +186,16 @@ UIEModel::UIEModel(const std::string& model_file,
                    const std::string& vocab_file, float position_prob,
                    size_t max_length, const std::vector<SchemaNode>& schema,
                    const fastdeploy::RuntimeOption& custom_option,
-                   const fastdeploy::ModelFormat& model_format)
+                   const fastdeploy::ModelFormat& model_format,
+                   SchemaLanguage schema_language)
     : max_length_(max_length), position_prob_(position_prob),
+      schema_language_(schema_language),
       tokenizer_(vocab_file) {
-  runtime_option_ = custom_option;
-  runtime_option_.model_format = model_format;
-  runtime_option_.SetModelPath(model_file, params_file);
-  runtime_.Init(runtime_option_);
+  runtime_option = custom_option;
+  runtime_option.model_format = model_format;
+  runtime_option.SetModelPath(model_file, params_file);
+  initialized = Initialize();
   SetSchema(schema);
   tokenizer_.EnableTruncMethod(
       max_length, 0, faster_tokenizer::core::Direction::RIGHT,
       faster_tokenizer::core::TruncStrategy::LONGEST_FIRST);
@@ -203,20 +207,33 @@ UIEModel::UIEModel(const std::string& model_file,
                    const std::string& vocab_file, float position_prob,
                    size_t max_length, const SchemaNode& schema,
                    const fastdeploy::RuntimeOption& custom_option,
-                   const fastdeploy::ModelFormat& model_format)
+                   const fastdeploy::ModelFormat& model_format,
+                   SchemaLanguage schema_language)
     : max_length_(max_length), position_prob_(position_prob),
+      schema_language_(schema_language),
       tokenizer_(vocab_file) {
-  runtime_option_ = custom_option;
-  runtime_option_.model_format = model_format;
-  runtime_option_.SetModelPath(model_file, params_file);
-  runtime_.Init(runtime_option_);
+  runtime_option = custom_option;
+  runtime_option.model_format = model_format;
+  runtime_option.SetModelPath(model_file, params_file);
+  initialized = Initialize();
   SetSchema(schema);
   tokenizer_.EnableTruncMethod(
       max_length, 0, faster_tokenizer::core::Direction::RIGHT,
       faster_tokenizer::core::TruncStrategy::LONGEST_FIRST);
 }
 
+bool UIEModel::Initialize() {
+  SetValidBackend();
+  return InitRuntime();
+}
+
+void UIEModel::SetValidBackend() {
+  // TODO(zhoushunjie): Add lite backend in future
+  valid_cpu_backends = {Backend::ORT, Backend::OPENVINO, Backend::PDINFER};
+  valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
+}
+
 void UIEModel::SetSchema(const std::vector<std::string>& schema) {
   schema_ = fastdeploy::utils::make_unique<Schema>(schema);
 }
@@ -463,7 +480,7 @@ void UIEModel::AutoJoiner(const std::vector<std::string>& short_texts,
   *results = std::move(final_result);
 }
 
-void UIEModel::ConstructTextsAndPrompts(
+bool UIEModel::ConstructTextsAndPrompts(
     const std::vector<std::string>& raw_texts, const std::string& node_name,
     const std::vector<std::vector<std::string>> node_prefix,
    std::vector<std::string>* input_texts, std::vector<std::string>* prompts,
@@ -496,6 +513,10 @@
     }
   }
 
+  if (prompts->size() == 0) {
+    return false;
+  }
+
   // Shortten the input texts and prompts
   auto max_prompt_iter = std::max_element(
       prompts->begin(), prompts->end(),
@@ -506,7 +527,6 @@
             rhs.c_str(), rhs.length());
         return lhs_ulen < rhs_ulen;
       });
-
   auto max_prompt_len = faster_tokenizer::utils::GetUnicodeLenFromUTF8(
       max_prompt_iter->c_str(), max_prompt_iter->length());
   auto max_predict_len = max_length_ - 3 - max_prompt_len;
@@ -521,6 +541,7 @@
   }
   (*input_texts) = std::move(short_texts);
   (*prompts) = std::move(short_texts_prompts);
+  return true;
 }
 
 void UIEModel::Preprocess(
@@ -542,10 +563,10 @@
   if (batch_size > 0) {
     seq_len = (*encodings)[0].GetIds().size();
   }
-  inputs->resize(runtime_.NumInputs());
-  for (int i = 0; i < runtime_.NumInputs(); ++i) {
+  inputs->resize(NumInputsOfRuntime());
+  for (int i = 0; i < NumInputsOfRuntime(); ++i) {
     (*inputs)[i].Allocate({batch_size, seq_len}, fastdeploy::FDDataType::INT64,
-                          runtime_.GetInputInfo(i).name);
+                          InputInfoOfRuntime(i).name);
   }
 
   // 2.2 Set the value of data
@@ -619,8 +640,13 @@
     auto&& input_mapping_item = input_mapping_with_raw_texts[i];
     for (auto&& idx : input_mapping_item) {
       for (int j = 0; j < results_list[idx].size(); ++j) {
-        // Note(zhoushunjie): It's useful for Chinese model.
-        auto prefix_str = results_list[idx][j].text_ + "\xe7\x9a\x84";
+        std::string prefix_str;
+        if (schema_language_ == SchemaLanguage::ZH) {
+          // Note(zhoushunjie): It means "of" in Chinese.
+          prefix_str = results_list[idx][j].text_ + "\xe7\x9a\x84";
+        } else {
+          prefix_str = " of " + results_list[idx][j].text_;
+        }
         (*prefix)[i].push_back(prefix_str);
       }
     }
@@ -677,7 +703,7 @@
     }
   }
   for (int i = 0; i < curr_relations.size(); ++i) {
-    for (int j = 0; j < new_relations[i].size(); ++j) {
+    for (int j = 0; j < curr_relations[i].size(); ++j) {
      if (curr_relations[i][j]->relation_.count(node_name)) {
        auto& curr_relation = curr_relations[i][j]->relation_[node_name];
        for (auto&& curr_result_ref : curr_relation) {
@@ -706,24 +732,27 @@
     std::vector<std::string> short_input_texts;
     std::vector<std::string> short_prompts;
     // 1. Construct texts and prompts from raw text
-    ConstructTextsAndPrompts(
+    bool has_prompt = ConstructTextsAndPrompts(
         texts, node.name_, node.prefix_, &short_input_texts, &short_prompts,
         &input_mapping_with_raw_texts, &input_mapping_with_short_text);
-
-    // 2. Convert texts and prompts to FDTensor
-    std::vector<FDTensor> inputs;
-    std::vector<faster_tokenizer::core::Encoding> encodings;
-    Preprocess(short_input_texts, short_prompts, &encodings, &inputs);
-
-    // 3. Infer
-    std::vector<FDTensor> outputs(runtime_.NumOutputs());
-    runtime_.Infer(inputs, &outputs);
-
-    // 4. Convert FDTensor to UIEResult
     std::vector<std::vector<UIEResult>> results_list;
-    Postprocess(outputs, encodings, short_input_texts, short_prompts,
-                input_mapping_with_short_text, &results_list);
+    if (has_prompt) {
+      // 2. Convert texts and prompts to FDTensor
+      std::vector<FDTensor> inputs;
+      std::vector<faster_tokenizer::core::Encoding> encodings;
+      Preprocess(short_input_texts, short_prompts, &encodings, &inputs);
+      // 3. Infer
+      std::vector<FDTensor> outputs(NumOutputsOfRuntime());
+      if (!Infer(inputs, &outputs)) {
+        FDERROR << "Failed to inference while using model:" << ModelName()
+                << "." << std::endl;
+      }
+
+      // 4. Convert FDTensor to UIEResult
+      Postprocess(outputs, encodings, short_input_texts, short_prompts,
+                  input_mapping_with_short_text, &results_list);
+    }
 
     // 5. Construct the new relation of the UIEResult
     std::vector<std::vector<UIEResult*>> relations;
     ConstructChildRelations(node.relations_, input_mapping_with_raw_texts,
@@ -742,4 +771,4 @@
   }
 
 }  // namespace text
-}  // namespace fastdeploy
\ No newline at end of file
+}  // namespace fastdeploy
diff --git a/fastdeploy/text/uie/model.h b/fastdeploy/text/uie/model.h
index e867c21a8..48b21f8b0 100644
--- a/fastdeploy/text/uie/model.h
+++ b/fastdeploy/text/uie/model.h
@@ -75,6 +75,11 @@ struct FASTDEPLOY_DECL SchemaNode {
   }
 };
 
+enum SchemaLanguage {
+  ZH,  // Chinese
+  EN   // English
+};
+
 struct Schema {
   explicit Schema(const std::string& schema, const std::string& name = "root");
   explicit Schema(const std::vector<std::string>& schema_list,
@@ -89,7 +94,7 @@ struct Schema {
   friend class UIEModel;
 };
 
-struct FASTDEPLOY_DECL UIEModel {
+struct FASTDEPLOY_DECL UIEModel : public FastDeployModel {
  public:
   UIEModel(const std::string& model_file, const std::string& params_file,
            const std::string& vocab_file, float position_prob,
@@ -97,26 +102,30 @@ struct FASTDEPLOY_DECL UIEModel {
            const fastdeploy::RuntimeOption& custom_option =
                fastdeploy::RuntimeOption(),
            const fastdeploy::ModelFormat& model_format =
-               fastdeploy::ModelFormat::PADDLE);
+               fastdeploy::ModelFormat::PADDLE,
+           SchemaLanguage schema_language = SchemaLanguage::ZH);
   UIEModel(const std::string& model_file, const std::string& params_file,
            const std::string& vocab_file, float position_prob,
            size_t max_length, const SchemaNode& schema,
            const fastdeploy::RuntimeOption& custom_option =
                fastdeploy::RuntimeOption(),
            const fastdeploy::ModelFormat& model_format =
-               fastdeploy::ModelFormat::PADDLE);
+               fastdeploy::ModelFormat::PADDLE,
+           SchemaLanguage schema_language = SchemaLanguage::ZH);
   UIEModel(const std::string& model_file, const std::string& params_file,
            const std::string& vocab_file, float position_prob,
            size_t max_length, const std::vector<SchemaNode>& schema,
            const fastdeploy::RuntimeOption& custom_option =
                fastdeploy::RuntimeOption(),
            const fastdeploy::ModelFormat& model_format =
-               fastdeploy::ModelFormat::PADDLE);
+               fastdeploy::ModelFormat::PADDLE,
+           SchemaLanguage schema_language = SchemaLanguage::ZH);
+  virtual std::string ModelName() const { return "UIEModel"; }
 
   void SetSchema(const std::vector<std::string>& schema);
   void SetSchema(const std::vector<SchemaNode>& schema);
   void SetSchema(const SchemaNode& schema);
-  void ConstructTextsAndPrompts(
+  bool ConstructTextsAndPrompts(
       const std::vector<std::string>& raw_texts, const std::string& node_name,
       const std::vector<std::vector<std::string>> node_prefix,
       std::vector<std::string>* input_texts, std::vector<std::string>* prompts,
@@ -150,7 +159,7 @@
       std::vector<std::unordered_map<std::string, std::vector<UIEResult>>>*
           results);
 
- private:
+ protected:
   using IDX_PROB = std::pair<int64_t, float>;
   struct IdxProbCmp {
     bool operator()(const std::pair<IDX_PROB, IDX_PROB>& lhs,
@@ -161,6 +170,8 @@
     faster_tokenizer::core::Offset offset_;
     bool is_prompt_;
   };
+  void SetValidBackend();
+  bool Initialize();
   void AutoSplitter(const std::vector<std::string>& texts, size_t max_length,
                     std::vector<std::string>* short_texts,
                     std::vector<std::vector<size_t>>* input_mapping);
@@ -185,11 +196,10 @@
       const std::vector<std::vector<SpanIdx>>& span_idxs,
       const std::vector<std::vector<float>>& probs,
       std::vector<std::vector<UIEResult>>* results) const;
-  fastdeploy::RuntimeOption runtime_option_;
-  fastdeploy::Runtime runtime_;
   std::unique_ptr<Schema> schema_;
   size_t max_length_;
   float position_prob_;
+  SchemaLanguage schema_language_;
   faster_tokenizer::tokenizers_impl::ErnieFasterTokenizer tokenizer_;
 };
 
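Because UIEModel now derives from FastDeployModel, backend selection goes through the usual valid-backend lists set in `SetValidBackend()` above (ORT, OpenVINO and Paddle Inference on CPU; ORT, Paddle Inference and TensorRT on GPU), and an unsupported choice fails at construction time via `initialized = Initialize()`. A hedged Python sketch of picking one of those backends; the `RuntimeOption` helpers are the standard FastDeploy ones (assumed here, not part of this patch), and the model paths are hypothetical:

```python
import fastdeploy as fd
from fastdeploy.text import UIEModel, SchemaLanguage

option = fd.RuntimeOption()
option.use_cpu()
option.use_openvino_backend()  # or use_ort_backend() / use_paddle_backend()

# Hypothetical exported model directory.
uie = UIEModel(
    "uie_base_export/inference.pdmodel",
    "uie_base_export/inference.pdiparams",
    "uie_base_export/vocab.txt",
    position_prob=0.5,
    max_length=128,
    schema=["时间", "选手", "赛事名称"],
    runtime_option=option,
    schema_language=SchemaLanguage.ZH)
# If the backend cannot be set up, the Python wrapper's
# "assert self.initialized" (added below) fails here rather than at predict time.
```
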
diff --git a/fastdeploy/text/uie/uie_pybind.cc b/fastdeploy/text/uie/uie_pybind.cc
index 9e47c326f..146dcc2c6 100644
--- a/fastdeploy/text/uie/uie_pybind.cc
+++ b/fastdeploy/text/uie/uie_pybind.cc
@@ -28,26 +28,34 @@ void BindUIE(pybind11::module& m) {
       .def_readwrite("relations", &text::SchemaNode::relations_)
       .def_readwrite("children", &text::SchemaNode::children_);
 
-  py::class_<text::UIEModel>(m, "UIEModel")
+  py::enum_<text::SchemaLanguage>(m, "SchemaLanguage", py::arithmetic(),
+                                  "The language of schema.")
+      .value("ZH", text::SchemaLanguage::ZH)
+      .value("EN", text::SchemaLanguage::EN);
+
+  py::class_<text::UIEModel, FastDeployModel>(m, "UIEModel")
      .def(py::init<std::string, std::string, std::string, float, size_t,
-                    std::vector<std::string>, RuntimeOption, ModelFormat>(),
+                    std::vector<std::string>, RuntimeOption, ModelFormat, text::SchemaLanguage>(),
           py::arg("model_file"), py::arg("params_file"), py::arg("vocab_file"),
           py::arg("position_prob"), py::arg("max_length"), py::arg("schema"),
           py::arg("custom_option") = fastdeploy::RuntimeOption(),
-          py::arg("model_format") = fastdeploy::ModelFormat::PADDLE)
+          py::arg("model_format") = fastdeploy::ModelFormat::PADDLE,
+          py::arg("schema_language") = text::SchemaLanguage::ZH)
      .def(
          py::init<std::string, std::string, std::string, float, size_t,
-                  std::vector<text::SchemaNode>, RuntimeOption, ModelFormat>(),
+                  std::vector<text::SchemaNode>, RuntimeOption, ModelFormat, text::SchemaLanguage>(),
          py::arg("model_file"), py::arg("params_file"), py::arg("vocab_file"),
          py::arg("position_prob"), py::arg("max_length"), py::arg("schema"),
          py::arg("custom_option") = fastdeploy::RuntimeOption(),
-         py::arg("model_format") = fastdeploy::ModelFormat::PADDLE)
+         py::arg("model_format") = fastdeploy::ModelFormat::PADDLE,
+         py::arg("schema_language") = text::SchemaLanguage::ZH)
      .def(py::init<std::string, std::string, std::string, float, size_t,
-                   text::SchemaNode, RuntimeOption, ModelFormat>(),
+                   text::SchemaNode, RuntimeOption, ModelFormat, text::SchemaLanguage>(),
          py::arg("model_file"), py::arg("params_file"), py::arg("vocab_file"),
          py::arg("position_prob"), py::arg("max_length"), py::arg("schema"),
          py::arg("custom_option") = fastdeploy::RuntimeOption(),
-         py::arg("model_format") = fastdeploy::ModelFormat::PADDLE)
+         py::arg("model_format") = fastdeploy::ModelFormat::PADDLE,
+         py::arg("schema_language") = text::SchemaLanguage::ZH)
      .def("set_schema",
           static_cast<void (text::UIEModel::*)(
               const std::vector<std::string>&)>(&text::UIEModel::SetSchema),
diff --git a/python/fastdeploy/text/__init__.py b/python/fastdeploy/text/__init__.py
index 63af1fdd4..4d8b20695 100644
--- a/python/fastdeploy/text/__init__.py
+++ b/python/fastdeploy/text/__init__.py
@@ -15,3 +15,4 @@
 from __future__ import absolute_import
 from . import uie
 from .uie import UIEModel
+from .uie import SchemaLanguage
diff --git a/python/fastdeploy/text/uie/__init__.py b/python/fastdeploy/text/uie/__init__.py
index b6676159a..89c7147af 100644
--- a/python/fastdeploy/text/uie/__init__.py
+++ b/python/fastdeploy/text/uie/__init__.py
@@ -15,11 +15,14 @@
 
 from __future__ import absolute_import
 import logging
-from ... import ModelFormat
-from ... import RuntimeOption
+from ... import RuntimeOption, FastDeployModel, ModelFormat
 from ... import c_lib_wrap as C
 
 
+class SchemaLanguage(C.text.SchemaLanguage):
+    pass
+
+
 class SchemaNode(object):
     def __init__(self, name, children=[]):
         schema_node_children = []
@@ -38,7 +41,7 @@ class SchemaNode(object):
         self._schema_node_children = schema_node_children
 
 
-class UIEModel(object):
+class UIEModel(FastDeployModel):
     def __init__(self,
                  model_file,
                  params_file,
@@ -47,7 +50,8 @@
                  max_length=128,
                  schema=[],
                  runtime_option=RuntimeOption(),
-                 model_format=ModelFormat.PADDLE):
+                 model_format=ModelFormat.PADDLE,
+                 schema_language=SchemaLanguage.ZH):
         if isinstance(schema, list):
             schema = SchemaNode("", schema)._schema_node_children
         elif isinstance(schema, dict):
             schema_tmp = []
             for s, option in schema.items():
                 schema_tmp.append(SchemaNode(s, option))
             schema = schema_tmp
         else:
             assert "The type of schema should be list or dict."
-        self._model = C.text.UIEModel(model_file, params_file, vocab_file,
-                                      position_prob, max_length, schema,
-                                      runtime_option._option, model_format)
+        self._model = C.text.UIEModel(
+            model_file, params_file, vocab_file, position_prob, max_length,
+            schema, runtime_option._option, model_format, schema_language)
+        assert self.initialized, "UIEModel initialize failed."
 
     def set_schema(self, schema):
         if isinstance(schema, list):
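As the wrapper above shows, a dict schema is turned into one SchemaNode per key (the value becoming its children), a list schema is wrapped the same way, and `set_schema` accepts either form at runtime. A short usage sketch of the upgraded Python API; the exported model paths are hypothetical, the schemas and texts are the usual UIE demo values, and it assumes the wrapper's `predict()` helper:

```python
from fastdeploy import RuntimeOption
from fastdeploy.text import UIEModel, SchemaLanguage

# Hypothetical exported model files.
uie = UIEModel(
    "uie_base_export/inference.pdmodel",
    "uie_base_export/inference.pdiparams",
    "uie_base_export/vocab.txt",
    position_prob=0.5,
    max_length=128,
    # dict schema: relation extraction, key -> subject, values -> its relations
    schema={"竞赛名称": ["主办方", "承办方", "已举办次数"]},
    runtime_option=RuntimeOption(),
    schema_language=SchemaLanguage.ZH)
print(uie.predict(["2022语言与智能技术竞赛由中国中文信息学会和中国计算机学会联合主办。"]))

# Switch to a flat entity schema without rebuilding the model.
uie.set_schema(["时间", "选手", "赛事名称"])
print(uie.predict(["2月8日上午北京冬奥会自由式滑雪女子大跳台决赛中中国选手谷爱凌以188.25分获得金牌!"]))
```

With the "skip infer when no prompts" change in model.cc, a schema node that produces no prompts for the given texts no longer triggers an inference call; it simply yields empty results for that node.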