mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Model] Upgrade uie (#458)
* Upgrade uie c++ implement * upgrade python UIEModel inherit FastDeployModel * Add schema language support; Skip infer when no prompts * Adjust the schema language arg pos * Add schema_language for python and cpp * update pybind for uie * Fix the args of uie * Add SchemaLanguage
This commit is contained in:
@@ -411,20 +411,24 @@ UIEModel(
|
|||||||
const std::vector<std::string>& schema,
|
const std::vector<std::string>& schema,
|
||||||
const fastdeploy::RuntimeOption& custom_option =
|
const fastdeploy::RuntimeOption& custom_option =
|
||||||
fastdeploy::RuntimeOption(),
|
fastdeploy::RuntimeOption(),
|
||||||
const fastdeploy::ModelFormat& model_format = fastdeploy::ModelFormat::PADDLE);
|
const fastdeploy::ModelFormat& model_format = fastdeploy::ModelFormat::PADDLE,
|
||||||
|
SchemaLanguage schema_language = SchemaLanguage::ZH);
|
||||||
UIEModel(
|
UIEModel(
|
||||||
const std::string& model_file, const std::string& params_file,
|
const std::string& model_file, const std::string& params_file,
|
||||||
const std::string& vocab_file, float position_prob, size_t max_length,
|
const std::string& vocab_file, float position_prob, size_t max_length,
|
||||||
const SchemaNode& schema, const fastdeploy::RuntimeOption& custom_option =
|
const SchemaNode& schema, const fastdeploy::RuntimeOption& custom_option =
|
||||||
fastdeploy::RuntimeOption(),
|
fastdeploy::RuntimeOption(),
|
||||||
const fastdeploy::ModelFormat& model_format = fastdeploy::ModelFormat::PADDLE);
|
const fastdeploy::ModelFormat& model_format = fastdeploy::ModelFormat::PADDLE,
|
||||||
|
SchemaLanguage schema_language = SchemaLanguage::ZH);
|
||||||
UIEModel(
|
UIEModel(
|
||||||
const std::string& model_file, const std::string& params_file,
|
const std::string& model_file, const std::string& params_file,
|
||||||
const std::string& vocab_file, float position_prob, size_t max_length,
|
const std::string& vocab_file, float position_prob, size_t max_length,
|
||||||
const std::vector<SchemaNode>& schema,
|
const std::vector<SchemaNode>& schema,
|
||||||
const fastdeploy::RuntimeOption& custom_option =
|
const fastdeploy::RuntimeOption& custom_option =
|
||||||
fastdeploy::RuntimeOption(),
|
fastdeploy::RuntimeOption(),
|
||||||
const fastdeploy::ModelFormat& model_format = fastdeploy::ModelFormat::PADDLE);
|
const fastdeploy::ModelFormat& model_format =
|
||||||
|
fastdeploy::ModelFormat::PADDLE,
|
||||||
|
SchemaLanguage schema_language = SchemaLanguage::ZH);
|
||||||
```
|
```
|
||||||
|
|
||||||
UIE模型加载和初始化,其中model_file, params_file为训练模型导出的Paddle inference文件,具体请参考其文档说明[模型导出](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/model_zoo/uie/README.md#%E6%A8%A1%E5%9E%8B%E9%83%A8%E7%BD%B2)。
|
UIE模型加载和初始化,其中model_file, params_file为训练模型导出的Paddle inference文件,具体请参考其文档说明[模型导出](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/model_zoo/uie/README.md#%E6%A8%A1%E5%9E%8B%E9%83%A8%E7%BD%B2)。
|
||||||
@@ -439,6 +443,7 @@ UIE模型加载和初始化,其中model_file, params_file为训练模型导出
|
|||||||
> * **schema**(list(SchemaNode) | SchemaNode | list(str)): 抽取任务的目标模式。
|
> * **schema**(list(SchemaNode) | SchemaNode | list(str)): 抽取任务的目标模式。
|
||||||
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
|
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
|
||||||
> * **model_format**(ModelFormat): 模型格式,默认为Paddle格式
|
> * **model_format**(ModelFormat): 模型格式,默认为Paddle格式
|
||||||
|
> * **schema_language** (SchemaLanguage): Schema 语言,默认为ZH(中文),目前支持的语言种类包括:ZH(中文),EN(英文)。
|
||||||
|
|
||||||
#### SetSchema函数
|
#### SetSchema函数
|
||||||
|
|
||||||
|
@@ -329,7 +329,9 @@ fd.text.uie.UIEModel(model_file,
|
|||||||
position_prob=0.5,
|
position_prob=0.5,
|
||||||
max_length=128,
|
max_length=128,
|
||||||
schema=[],
|
schema=[],
|
||||||
runtime_option=None,model_format=ModelFormat.PADDLE)
|
runtime_option=None,
|
||||||
|
model_format=ModelFormat.PADDLE,
|
||||||
|
schema_language=SchemaLanguage.ZH)
|
||||||
```
|
```
|
||||||
|
|
||||||
UIEModel模型加载和初始化,其中`model_file`, `params_file`为训练模型导出的Paddle inference文件,具体请参考其文档说明[模型导出](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/model_zoo/uie/README.md#%E6%A8%A1%E5%9E%8B%E9%83%A8%E7%BD%B2),`vocab_file`为词表文件,UIE模型的词表可在[UIE配置文件](https://github.com/PaddlePaddle/PaddleNLP/blob/5401f01af85f1c73d8017c6b3476242fce1e6d52/model_zoo/uie/utils.py)中下载相应的UIE模型的vocab_file。
|
UIEModel模型加载和初始化,其中`model_file`, `params_file`为训练模型导出的Paddle inference文件,具体请参考其文档说明[模型导出](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/model_zoo/uie/README.md#%E6%A8%A1%E5%9E%8B%E9%83%A8%E7%BD%B2),`vocab_file`为词表文件,UIE模型的词表可在[UIE配置文件](https://github.com/PaddlePaddle/PaddleNLP/blob/5401f01af85f1c73d8017c6b3476242fce1e6d52/model_zoo/uie/utils.py)中下载相应的UIE模型的vocab_file。
|
||||||
@@ -344,6 +346,7 @@ UIEModel模型加载和初始化,其中`model_file`, `params_file`为训练模
|
|||||||
> * **schema**(list|dict): 抽取任务的目标信息。
|
> * **schema**(list|dict): 抽取任务的目标信息。
|
||||||
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
|
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
|
||||||
> * **model_format**(ModelFormat): 模型格式,默认为Paddle格式
|
> * **model_format**(ModelFormat): 模型格式,默认为Paddle格式
|
||||||
|
> * **schema_language**(SchemaLanguage): Schema语言。默认为ZH(中文),目前支持的语言种类包括:ZH(中文),EN(英文)。
|
||||||
|
|
||||||
### set_schema函数
|
### set_schema函数
|
||||||
|
|
||||||
|
@@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import fastdeploy
|
import fastdeploy
|
||||||
from fastdeploy.text import UIEModel
|
from fastdeploy.text import UIEModel, SchemaLanguage
|
||||||
import os
|
import os
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
|
|
||||||
@@ -80,7 +80,8 @@ if __name__ == "__main__":
|
|||||||
position_prob=0.5,
|
position_prob=0.5,
|
||||||
max_length=128,
|
max_length=128,
|
||||||
schema=schema,
|
schema=schema,
|
||||||
runtime_option=runtime_option)
|
runtime_option=runtime_option,
|
||||||
|
schema_language=SchemaLanguage.ZH)
|
||||||
|
|
||||||
print("1. Named Entity Recognition Task")
|
print("1. Named Entity Recognition Task")
|
||||||
print(f"The extraction schema: {schema}")
|
print(f"The extraction schema: {schema}")
|
||||||
|
@@ -165,14 +165,16 @@ UIEModel::UIEModel(const std::string& model_file,
|
|||||||
const std::string& vocab_file, float position_prob,
|
const std::string& vocab_file, float position_prob,
|
||||||
size_t max_length, const std::vector<std::string>& schema,
|
size_t max_length, const std::vector<std::string>& schema,
|
||||||
const fastdeploy::RuntimeOption& custom_option,
|
const fastdeploy::RuntimeOption& custom_option,
|
||||||
const fastdeploy::ModelFormat& model_format)
|
const fastdeploy::ModelFormat& model_format,
|
||||||
|
SchemaLanguage schema_language)
|
||||||
: max_length_(max_length),
|
: max_length_(max_length),
|
||||||
position_prob_(position_prob),
|
position_prob_(position_prob),
|
||||||
|
schema_language_(schema_language),
|
||||||
tokenizer_(vocab_file) {
|
tokenizer_(vocab_file) {
|
||||||
runtime_option_ = custom_option;
|
runtime_option = custom_option;
|
||||||
runtime_option_.model_format = model_format;
|
runtime_option.model_format = model_format;
|
||||||
runtime_option_.SetModelPath(model_file, params_file);
|
runtime_option.SetModelPath(model_file, params_file);
|
||||||
runtime_.Init(runtime_option_);
|
initialized = Initialize();
|
||||||
SetSchema(schema);
|
SetSchema(schema);
|
||||||
tokenizer_.EnableTruncMethod(
|
tokenizer_.EnableTruncMethod(
|
||||||
max_length, 0, faster_tokenizer::core::Direction::RIGHT,
|
max_length, 0, faster_tokenizer::core::Direction::RIGHT,
|
||||||
@@ -184,14 +186,16 @@ UIEModel::UIEModel(const std::string& model_file,
|
|||||||
const std::string& vocab_file, float position_prob,
|
const std::string& vocab_file, float position_prob,
|
||||||
size_t max_length, const std::vector<SchemaNode>& schema,
|
size_t max_length, const std::vector<SchemaNode>& schema,
|
||||||
const fastdeploy::RuntimeOption& custom_option,
|
const fastdeploy::RuntimeOption& custom_option,
|
||||||
const fastdeploy::ModelFormat& model_format)
|
const fastdeploy::ModelFormat& model_format,
|
||||||
|
SchemaLanguage schema_language)
|
||||||
: max_length_(max_length),
|
: max_length_(max_length),
|
||||||
position_prob_(position_prob),
|
position_prob_(position_prob),
|
||||||
|
schema_language_(schema_language),
|
||||||
tokenizer_(vocab_file) {
|
tokenizer_(vocab_file) {
|
||||||
runtime_option_ = custom_option;
|
runtime_option = custom_option;
|
||||||
runtime_option_.model_format = model_format;
|
runtime_option.model_format = model_format;
|
||||||
runtime_option_.SetModelPath(model_file, params_file);
|
runtime_option.SetModelPath(model_file, params_file);
|
||||||
runtime_.Init(runtime_option_);
|
initialized = Initialize();
|
||||||
SetSchema(schema);
|
SetSchema(schema);
|
||||||
tokenizer_.EnableTruncMethod(
|
tokenizer_.EnableTruncMethod(
|
||||||
max_length, 0, faster_tokenizer::core::Direction::RIGHT,
|
max_length, 0, faster_tokenizer::core::Direction::RIGHT,
|
||||||
@@ -203,20 +207,33 @@ UIEModel::UIEModel(const std::string& model_file,
|
|||||||
const std::string& vocab_file, float position_prob,
|
const std::string& vocab_file, float position_prob,
|
||||||
size_t max_length, const SchemaNode& schema,
|
size_t max_length, const SchemaNode& schema,
|
||||||
const fastdeploy::RuntimeOption& custom_option,
|
const fastdeploy::RuntimeOption& custom_option,
|
||||||
const fastdeploy::ModelFormat& model_format)
|
const fastdeploy::ModelFormat& model_format,
|
||||||
|
SchemaLanguage schema_language)
|
||||||
: max_length_(max_length),
|
: max_length_(max_length),
|
||||||
position_prob_(position_prob),
|
position_prob_(position_prob),
|
||||||
|
schema_language_(schema_language),
|
||||||
tokenizer_(vocab_file) {
|
tokenizer_(vocab_file) {
|
||||||
runtime_option_ = custom_option;
|
runtime_option = custom_option;
|
||||||
runtime_option_.model_format = model_format;
|
runtime_option.model_format = model_format;
|
||||||
runtime_option_.SetModelPath(model_file, params_file);
|
runtime_option.SetModelPath(model_file, params_file);
|
||||||
runtime_.Init(runtime_option_);
|
initialized = Initialize();
|
||||||
SetSchema(schema);
|
SetSchema(schema);
|
||||||
tokenizer_.EnableTruncMethod(
|
tokenizer_.EnableTruncMethod(
|
||||||
max_length, 0, faster_tokenizer::core::Direction::RIGHT,
|
max_length, 0, faster_tokenizer::core::Direction::RIGHT,
|
||||||
faster_tokenizer::core::TruncStrategy::LONGEST_FIRST);
|
faster_tokenizer::core::TruncStrategy::LONGEST_FIRST);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool UIEModel::Initialize() {
|
||||||
|
SetValidBackend();
|
||||||
|
return InitRuntime();
|
||||||
|
}
|
||||||
|
|
||||||
|
void UIEModel::SetValidBackend() {
|
||||||
|
// TODO(zhoushunjie): Add lite backend in future
|
||||||
|
valid_cpu_backends = {Backend::ORT, Backend::OPENVINO, Backend::PDINFER};
|
||||||
|
valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
|
||||||
|
}
|
||||||
|
|
||||||
void UIEModel::SetSchema(const std::vector<std::string>& schema) {
|
void UIEModel::SetSchema(const std::vector<std::string>& schema) {
|
||||||
schema_ = fastdeploy::utils::make_unique<Schema>(schema);
|
schema_ = fastdeploy::utils::make_unique<Schema>(schema);
|
||||||
}
|
}
|
||||||
@@ -463,7 +480,7 @@ void UIEModel::AutoJoiner(const std::vector<std::string>& short_texts,
|
|||||||
*results = std::move(final_result);
|
*results = std::move(final_result);
|
||||||
}
|
}
|
||||||
|
|
||||||
void UIEModel::ConstructTextsAndPrompts(
|
bool UIEModel::ConstructTextsAndPrompts(
|
||||||
const std::vector<std::string>& raw_texts, const std::string& node_name,
|
const std::vector<std::string>& raw_texts, const std::string& node_name,
|
||||||
const std::vector<std::vector<std::string>> node_prefix,
|
const std::vector<std::vector<std::string>> node_prefix,
|
||||||
std::vector<std::string>* input_texts, std::vector<std::string>* prompts,
|
std::vector<std::string>* input_texts, std::vector<std::string>* prompts,
|
||||||
@@ -496,6 +513,10 @@ void UIEModel::ConstructTextsAndPrompts(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (prompts->size() == 0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// Shortten the input texts and prompts
|
// Shortten the input texts and prompts
|
||||||
auto max_prompt_iter = std::max_element(
|
auto max_prompt_iter = std::max_element(
|
||||||
prompts->begin(), prompts->end(),
|
prompts->begin(), prompts->end(),
|
||||||
@@ -506,7 +527,6 @@ void UIEModel::ConstructTextsAndPrompts(
|
|||||||
rhs.c_str(), rhs.length());
|
rhs.c_str(), rhs.length());
|
||||||
return lhs_ulen < rhs_ulen;
|
return lhs_ulen < rhs_ulen;
|
||||||
});
|
});
|
||||||
|
|
||||||
auto max_prompt_len = faster_tokenizer::utils::GetUnicodeLenFromUTF8(
|
auto max_prompt_len = faster_tokenizer::utils::GetUnicodeLenFromUTF8(
|
||||||
max_prompt_iter->c_str(), max_prompt_iter->length());
|
max_prompt_iter->c_str(), max_prompt_iter->length());
|
||||||
auto max_predict_len = max_length_ - 3 - max_prompt_len;
|
auto max_predict_len = max_length_ - 3 - max_prompt_len;
|
||||||
@@ -521,6 +541,7 @@ void UIEModel::ConstructTextsAndPrompts(
|
|||||||
}
|
}
|
||||||
(*input_texts) = std::move(short_texts);
|
(*input_texts) = std::move(short_texts);
|
||||||
(*prompts) = std::move(short_texts_prompts);
|
(*prompts) = std::move(short_texts_prompts);
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void UIEModel::Preprocess(
|
void UIEModel::Preprocess(
|
||||||
@@ -542,10 +563,10 @@ void UIEModel::Preprocess(
|
|||||||
if (batch_size > 0) {
|
if (batch_size > 0) {
|
||||||
seq_len = (*encodings)[0].GetIds().size();
|
seq_len = (*encodings)[0].GetIds().size();
|
||||||
}
|
}
|
||||||
inputs->resize(runtime_.NumInputs());
|
inputs->resize(NumInputsOfRuntime());
|
||||||
for (int i = 0; i < runtime_.NumInputs(); ++i) {
|
for (int i = 0; i < NumInputsOfRuntime(); ++i) {
|
||||||
(*inputs)[i].Allocate({batch_size, seq_len}, fastdeploy::FDDataType::INT64,
|
(*inputs)[i].Allocate({batch_size, seq_len}, fastdeploy::FDDataType::INT64,
|
||||||
runtime_.GetInputInfo(i).name);
|
InputInfoOfRuntime(i).name);
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2.2 Set the value of data
|
// 2.2 Set the value of data
|
||||||
@@ -619,8 +640,13 @@ void UIEModel::ConstructChildPromptPrefix(
|
|||||||
auto&& input_mapping_item = input_mapping_with_raw_texts[i];
|
auto&& input_mapping_item = input_mapping_with_raw_texts[i];
|
||||||
for (auto&& idx : input_mapping_item) {
|
for (auto&& idx : input_mapping_item) {
|
||||||
for (int j = 0; j < results_list[idx].size(); ++j) {
|
for (int j = 0; j < results_list[idx].size(); ++j) {
|
||||||
// Note(zhoushunjie): It's useful for Chinese model.
|
std::string prefix_str;
|
||||||
auto prefix_str = results_list[idx][j].text_ + "\xe7\x9a\x84";
|
if (schema_language_ == SchemaLanguage::ZH) {
|
||||||
|
// Note(zhoushunjie): It means "of" in Chinese.
|
||||||
|
prefix_str = results_list[idx][j].text_ + "\xe7\x9a\x84";
|
||||||
|
} else {
|
||||||
|
prefix_str = " of " + results_list[idx][j].text_;
|
||||||
|
}
|
||||||
(*prefix)[i].push_back(prefix_str);
|
(*prefix)[i].push_back(prefix_str);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -677,7 +703,7 @@ void UIEModel::ConstructChildRelations(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (int i = 0; i < curr_relations.size(); ++i) {
|
for (int i = 0; i < curr_relations.size(); ++i) {
|
||||||
for (int j = 0; j < new_relations[i].size(); ++j) {
|
for (int j = 0; j < curr_relations[i].size(); ++j) {
|
||||||
if (curr_relations[i][j]->relation_.count(node_name)) {
|
if (curr_relations[i][j]->relation_.count(node_name)) {
|
||||||
auto& curr_relation = curr_relations[i][j]->relation_[node_name];
|
auto& curr_relation = curr_relations[i][j]->relation_[node_name];
|
||||||
for (auto&& curr_result_ref : curr_relation) {
|
for (auto&& curr_result_ref : curr_relation) {
|
||||||
@@ -706,24 +732,27 @@ void UIEModel::Predict(
|
|||||||
std::vector<std::string> short_input_texts;
|
std::vector<std::string> short_input_texts;
|
||||||
std::vector<std::string> short_prompts;
|
std::vector<std::string> short_prompts;
|
||||||
// 1. Construct texts and prompts from raw text
|
// 1. Construct texts and prompts from raw text
|
||||||
ConstructTextsAndPrompts(
|
bool has_prompt = ConstructTextsAndPrompts(
|
||||||
texts, node.name_, node.prefix_, &short_input_texts, &short_prompts,
|
texts, node.name_, node.prefix_, &short_input_texts, &short_prompts,
|
||||||
&input_mapping_with_raw_texts, &input_mapping_with_short_text);
|
&input_mapping_with_raw_texts, &input_mapping_with_short_text);
|
||||||
|
|
||||||
// 2. Convert texts and prompts to FDTensor
|
|
||||||
std::vector<FDTensor> inputs;
|
|
||||||
std::vector<faster_tokenizer::core::Encoding> encodings;
|
|
||||||
Preprocess(short_input_texts, short_prompts, &encodings, &inputs);
|
|
||||||
|
|
||||||
// 3. Infer
|
|
||||||
std::vector<fastdeploy::FDTensor> outputs(runtime_.NumOutputs());
|
|
||||||
runtime_.Infer(inputs, &outputs);
|
|
||||||
|
|
||||||
// 4. Convert FDTensor to UIEResult
|
|
||||||
std::vector<std::vector<UIEResult>> results_list;
|
std::vector<std::vector<UIEResult>> results_list;
|
||||||
Postprocess(outputs, encodings, short_input_texts, short_prompts,
|
if (has_prompt) {
|
||||||
input_mapping_with_short_text, &results_list);
|
// 2. Convert texts and prompts to FDTensor
|
||||||
|
std::vector<FDTensor> inputs;
|
||||||
|
std::vector<faster_tokenizer::core::Encoding> encodings;
|
||||||
|
Preprocess(short_input_texts, short_prompts, &encodings, &inputs);
|
||||||
|
|
||||||
|
// 3. Infer
|
||||||
|
std::vector<fastdeploy::FDTensor> outputs(NumOutputsOfRuntime());
|
||||||
|
if (!Infer(inputs, &outputs)) {
|
||||||
|
FDERROR << "Failed to inference while using model:" << ModelName()
|
||||||
|
<< "." << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Convert FDTensor to UIEResult
|
||||||
|
Postprocess(outputs, encodings, short_input_texts, short_prompts,
|
||||||
|
input_mapping_with_short_text, &results_list);
|
||||||
|
}
|
||||||
// 5. Construct the new relation of the UIEResult
|
// 5. Construct the new relation of the UIEResult
|
||||||
std::vector<std::vector<UIEResult*>> relations;
|
std::vector<std::vector<UIEResult*>> relations;
|
||||||
ConstructChildRelations(node.relations_, input_mapping_with_raw_texts,
|
ConstructChildRelations(node.relations_, input_mapping_with_raw_texts,
|
||||||
|
@@ -75,6 +75,11 @@ struct FASTDEPLOY_DECL SchemaNode {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum SchemaLanguage {
|
||||||
|
ZH, // Chinese
|
||||||
|
EN // English
|
||||||
|
};
|
||||||
|
|
||||||
struct Schema {
|
struct Schema {
|
||||||
explicit Schema(const std::string& schema, const std::string& name = "root");
|
explicit Schema(const std::string& schema, const std::string& name = "root");
|
||||||
explicit Schema(const std::vector<std::string>& schema_list,
|
explicit Schema(const std::vector<std::string>& schema_list,
|
||||||
@@ -89,7 +94,7 @@ struct Schema {
|
|||||||
friend class UIEModel;
|
friend class UIEModel;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct FASTDEPLOY_DECL UIEModel {
|
struct FASTDEPLOY_DECL UIEModel : public FastDeployModel {
|
||||||
public:
|
public:
|
||||||
UIEModel(const std::string& model_file, const std::string& params_file,
|
UIEModel(const std::string& model_file, const std::string& params_file,
|
||||||
const std::string& vocab_file, float position_prob,
|
const std::string& vocab_file, float position_prob,
|
||||||
@@ -97,26 +102,30 @@ struct FASTDEPLOY_DECL UIEModel {
|
|||||||
const fastdeploy::RuntimeOption& custom_option =
|
const fastdeploy::RuntimeOption& custom_option =
|
||||||
fastdeploy::RuntimeOption(),
|
fastdeploy::RuntimeOption(),
|
||||||
const fastdeploy::ModelFormat& model_format =
|
const fastdeploy::ModelFormat& model_format =
|
||||||
fastdeploy::ModelFormat::PADDLE);
|
fastdeploy::ModelFormat::PADDLE,
|
||||||
|
SchemaLanguage schema_language = SchemaLanguage::ZH);
|
||||||
UIEModel(const std::string& model_file, const std::string& params_file,
|
UIEModel(const std::string& model_file, const std::string& params_file,
|
||||||
const std::string& vocab_file, float position_prob,
|
const std::string& vocab_file, float position_prob,
|
||||||
size_t max_length, const SchemaNode& schema,
|
size_t max_length, const SchemaNode& schema,
|
||||||
const fastdeploy::RuntimeOption& custom_option =
|
const fastdeploy::RuntimeOption& custom_option =
|
||||||
fastdeploy::RuntimeOption(),
|
fastdeploy::RuntimeOption(),
|
||||||
const fastdeploy::ModelFormat& model_format =
|
const fastdeploy::ModelFormat& model_format =
|
||||||
fastdeploy::ModelFormat::PADDLE);
|
fastdeploy::ModelFormat::PADDLE,
|
||||||
|
SchemaLanguage schema_language = SchemaLanguage::ZH);
|
||||||
UIEModel(const std::string& model_file, const std::string& params_file,
|
UIEModel(const std::string& model_file, const std::string& params_file,
|
||||||
const std::string& vocab_file, float position_prob,
|
const std::string& vocab_file, float position_prob,
|
||||||
size_t max_length, const std::vector<SchemaNode>& schema,
|
size_t max_length, const std::vector<SchemaNode>& schema,
|
||||||
const fastdeploy::RuntimeOption& custom_option =
|
const fastdeploy::RuntimeOption& custom_option =
|
||||||
fastdeploy::RuntimeOption(),
|
fastdeploy::RuntimeOption(),
|
||||||
const fastdeploy::ModelFormat& model_format =
|
const fastdeploy::ModelFormat& model_format =
|
||||||
fastdeploy::ModelFormat::PADDLE);
|
fastdeploy::ModelFormat::PADDLE,
|
||||||
|
SchemaLanguage schema_language = SchemaLanguage::ZH);
|
||||||
|
virtual std::string ModelName() const { return "UIEModel"; }
|
||||||
void SetSchema(const std::vector<std::string>& schema);
|
void SetSchema(const std::vector<std::string>& schema);
|
||||||
void SetSchema(const std::vector<SchemaNode>& schema);
|
void SetSchema(const std::vector<SchemaNode>& schema);
|
||||||
void SetSchema(const SchemaNode& schema);
|
void SetSchema(const SchemaNode& schema);
|
||||||
|
|
||||||
void ConstructTextsAndPrompts(
|
bool ConstructTextsAndPrompts(
|
||||||
const std::vector<std::string>& raw_texts, const std::string& node_name,
|
const std::vector<std::string>& raw_texts, const std::string& node_name,
|
||||||
const std::vector<std::vector<std::string>> node_prefix,
|
const std::vector<std::vector<std::string>> node_prefix,
|
||||||
std::vector<std::string>* input_texts, std::vector<std::string>* prompts,
|
std::vector<std::string>* input_texts, std::vector<std::string>* prompts,
|
||||||
@@ -150,7 +159,7 @@ struct FASTDEPLOY_DECL UIEModel {
|
|||||||
std::vector<std::unordered_map<std::string, std::vector<UIEResult>>>*
|
std::vector<std::unordered_map<std::string, std::vector<UIEResult>>>*
|
||||||
results);
|
results);
|
||||||
|
|
||||||
private:
|
protected:
|
||||||
using IDX_PROB = std::pair<int64_t, float>;
|
using IDX_PROB = std::pair<int64_t, float>;
|
||||||
struct IdxProbCmp {
|
struct IdxProbCmp {
|
||||||
bool operator()(const std::pair<IDX_PROB, IDX_PROB>& lhs,
|
bool operator()(const std::pair<IDX_PROB, IDX_PROB>& lhs,
|
||||||
@@ -161,6 +170,8 @@ struct FASTDEPLOY_DECL UIEModel {
|
|||||||
faster_tokenizer::core::Offset offset_;
|
faster_tokenizer::core::Offset offset_;
|
||||||
bool is_prompt_;
|
bool is_prompt_;
|
||||||
};
|
};
|
||||||
|
void SetValidBackend();
|
||||||
|
bool Initialize();
|
||||||
void AutoSplitter(const std::vector<std::string>& texts, size_t max_length,
|
void AutoSplitter(const std::vector<std::string>& texts, size_t max_length,
|
||||||
std::vector<std::string>* short_texts,
|
std::vector<std::string>* short_texts,
|
||||||
std::vector<std::vector<size_t>>* input_mapping);
|
std::vector<std::vector<size_t>>* input_mapping);
|
||||||
@@ -185,11 +196,10 @@ struct FASTDEPLOY_DECL UIEModel {
|
|||||||
const std::vector<std::vector<SpanIdx>>& span_idxs,
|
const std::vector<std::vector<SpanIdx>>& span_idxs,
|
||||||
const std::vector<std::vector<float>>& probs,
|
const std::vector<std::vector<float>>& probs,
|
||||||
std::vector<std::vector<UIEResult>>* results) const;
|
std::vector<std::vector<UIEResult>>* results) const;
|
||||||
fastdeploy::RuntimeOption runtime_option_;
|
|
||||||
fastdeploy::Runtime runtime_;
|
|
||||||
std::unique_ptr<Schema> schema_;
|
std::unique_ptr<Schema> schema_;
|
||||||
size_t max_length_;
|
size_t max_length_;
|
||||||
float position_prob_;
|
float position_prob_;
|
||||||
|
SchemaLanguage schema_language_;
|
||||||
faster_tokenizer::tokenizers_impl::ErnieFasterTokenizer tokenizer_;
|
faster_tokenizer::tokenizers_impl::ErnieFasterTokenizer tokenizer_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@@ -28,26 +28,34 @@ void BindUIE(pybind11::module& m) {
|
|||||||
.def_readwrite("relations", &text::SchemaNode::relations_)
|
.def_readwrite("relations", &text::SchemaNode::relations_)
|
||||||
.def_readwrite("children", &text::SchemaNode::children_);
|
.def_readwrite("children", &text::SchemaNode::children_);
|
||||||
|
|
||||||
py::class_<text::UIEModel>(m, "UIEModel")
|
py::enum_<text::SchemaLanguage>(m, "SchemaLanguage", py::arithmetic(),
|
||||||
|
"The language of schema.")
|
||||||
|
.value("ZH", text::SchemaLanguage::ZH)
|
||||||
|
.value("EN", text::SchemaLanguage::EN);
|
||||||
|
|
||||||
|
py::class_<text::UIEModel, FastDeployModel>(m, "UIEModel")
|
||||||
.def(py::init<std::string, std::string, std::string, float, size_t,
|
.def(py::init<std::string, std::string, std::string, float, size_t,
|
||||||
std::vector<std::string>, RuntimeOption, ModelFormat>(),
|
std::vector<std::string>, RuntimeOption, ModelFormat, text::SchemaLanguage>(),
|
||||||
py::arg("model_file"), py::arg("params_file"), py::arg("vocab_file"),
|
py::arg("model_file"), py::arg("params_file"), py::arg("vocab_file"),
|
||||||
py::arg("position_prob"), py::arg("max_length"), py::arg("schema"),
|
py::arg("position_prob"), py::arg("max_length"), py::arg("schema"),
|
||||||
py::arg("custom_option") = fastdeploy::RuntimeOption(),
|
py::arg("custom_option") = fastdeploy::RuntimeOption(),
|
||||||
py::arg("model_format") = fastdeploy::ModelFormat::PADDLE)
|
py::arg("model_format") = fastdeploy::ModelFormat::PADDLE,
|
||||||
|
py::arg("schema_language") = text::SchemaLanguage::ZH)
|
||||||
.def(
|
.def(
|
||||||
py::init<std::string, std::string, std::string, float, size_t,
|
py::init<std::string, std::string, std::string, float, size_t,
|
||||||
std::vector<text::SchemaNode>, RuntimeOption, ModelFormat>(),
|
std::vector<text::SchemaNode>, RuntimeOption, ModelFormat, text::SchemaLanguage>(),
|
||||||
py::arg("model_file"), py::arg("params_file"), py::arg("vocab_file"),
|
py::arg("model_file"), py::arg("params_file"), py::arg("vocab_file"),
|
||||||
py::arg("position_prob"), py::arg("max_length"), py::arg("schema"),
|
py::arg("position_prob"), py::arg("max_length"), py::arg("schema"),
|
||||||
py::arg("custom_option") = fastdeploy::RuntimeOption(),
|
py::arg("custom_option") = fastdeploy::RuntimeOption(),
|
||||||
py::arg("model_format") = fastdeploy::ModelFormat::PADDLE)
|
py::arg("model_format") = fastdeploy::ModelFormat::PADDLE,
|
||||||
|
py::arg("schema_language") = text::SchemaLanguage::ZH)
|
||||||
.def(py::init<std::string, std::string, std::string, float, size_t,
|
.def(py::init<std::string, std::string, std::string, float, size_t,
|
||||||
text::SchemaNode, RuntimeOption, ModelFormat>(),
|
text::SchemaNode, RuntimeOption, ModelFormat, text::SchemaLanguage>(),
|
||||||
py::arg("model_file"), py::arg("params_file"), py::arg("vocab_file"),
|
py::arg("model_file"), py::arg("params_file"), py::arg("vocab_file"),
|
||||||
py::arg("position_prob"), py::arg("max_length"), py::arg("schema"),
|
py::arg("position_prob"), py::arg("max_length"), py::arg("schema"),
|
||||||
py::arg("custom_option") = fastdeploy::RuntimeOption(),
|
py::arg("custom_option") = fastdeploy::RuntimeOption(),
|
||||||
py::arg("model_format") = fastdeploy::ModelFormat::PADDLE)
|
py::arg("model_format") = fastdeploy::ModelFormat::PADDLE,
|
||||||
|
py::arg("schema_language") = text::SchemaLanguage::ZH)
|
||||||
.def("set_schema",
|
.def("set_schema",
|
||||||
static_cast<void (text::UIEModel::*)(
|
static_cast<void (text::UIEModel::*)(
|
||||||
const std::vector<std::string>&)>(&text::UIEModel::SetSchema),
|
const std::vector<std::string>&)>(&text::UIEModel::SetSchema),
|
||||||
|
@@ -15,3 +15,4 @@ from __future__ import absolute_import
|
|||||||
|
|
||||||
from . import uie
|
from . import uie
|
||||||
from .uie import UIEModel
|
from .uie import UIEModel
|
||||||
|
from .uie import SchemaLanguage
|
||||||
|
@@ -15,11 +15,14 @@
|
|||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from ... import ModelFormat
|
from ... import RuntimeOption, FastDeployModel, ModelFormat
|
||||||
from ... import RuntimeOption
|
|
||||||
from ... import c_lib_wrap as C
|
from ... import c_lib_wrap as C
|
||||||
|
|
||||||
|
|
||||||
|
class SchemaLanguage(C.text.SchemaLanguage):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SchemaNode(object):
|
class SchemaNode(object):
|
||||||
def __init__(self, name, children=[]):
|
def __init__(self, name, children=[]):
|
||||||
schema_node_children = []
|
schema_node_children = []
|
||||||
@@ -38,7 +41,7 @@ class SchemaNode(object):
|
|||||||
self._schema_node_children = schema_node_children
|
self._schema_node_children = schema_node_children
|
||||||
|
|
||||||
|
|
||||||
class UIEModel(object):
|
class UIEModel(FastDeployModel):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model_file,
|
model_file,
|
||||||
params_file,
|
params_file,
|
||||||
@@ -47,7 +50,8 @@ class UIEModel(object):
|
|||||||
max_length=128,
|
max_length=128,
|
||||||
schema=[],
|
schema=[],
|
||||||
runtime_option=RuntimeOption(),
|
runtime_option=RuntimeOption(),
|
||||||
model_format=ModelFormat.PADDLE):
|
model_format=ModelFormat.PADDLE,
|
||||||
|
schema_language=SchemaLanguage.ZH):
|
||||||
if isinstance(schema, list):
|
if isinstance(schema, list):
|
||||||
schema = SchemaNode("", schema)._schema_node_children
|
schema = SchemaNode("", schema)._schema_node_children
|
||||||
elif isinstance(schema, dict):
|
elif isinstance(schema, dict):
|
||||||
@@ -57,9 +61,10 @@ class UIEModel(object):
|
|||||||
schema = schema_tmp
|
schema = schema_tmp
|
||||||
else:
|
else:
|
||||||
assert "The type of schema should be list or dict."
|
assert "The type of schema should be list or dict."
|
||||||
self._model = C.text.UIEModel(model_file, params_file, vocab_file,
|
self._model = C.text.UIEModel(
|
||||||
position_prob, max_length, schema,
|
model_file, params_file, vocab_file, position_prob, max_length,
|
||||||
runtime_option._option, model_format)
|
schema, runtime_option._option, model_format, schema_language)
|
||||||
|
assert self.initialized, "UIEModel initialize failed."
|
||||||
|
|
||||||
def set_schema(self, schema):
|
def set_schema(self, schema):
|
||||||
if isinstance(schema, list):
|
if isinstance(schema, list):
|
||||||
|
Reference in New Issue
Block a user