mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Model] Upgrade uie (#458)
* Upgrade uie c++ implement * upgrade python UIEModel inherit FastDeployModel * Add schema language support; Skip infer when no prompts * Adjust the schema language arg pos * Add schema_language for python and cpp * update pybind for uie * Fix the args of uie * Add SchemaLanguage
This commit is contained in:
@@ -411,20 +411,24 @@ UIEModel(
|
||||
const std::vector<std::string>& schema,
|
||||
const fastdeploy::RuntimeOption& custom_option =
|
||||
fastdeploy::RuntimeOption(),
|
||||
const fastdeploy::ModelFormat& model_format = fastdeploy::ModelFormat::PADDLE);
|
||||
const fastdeploy::ModelFormat& model_format = fastdeploy::ModelFormat::PADDLE,
|
||||
SchemaLanguage schema_language = SchemaLanguage::ZH);
|
||||
UIEModel(
|
||||
const std::string& model_file, const std::string& params_file,
|
||||
const std::string& vocab_file, float position_prob, size_t max_length,
|
||||
const SchemaNode& schema, const fastdeploy::RuntimeOption& custom_option =
|
||||
fastdeploy::RuntimeOption(),
|
||||
const fastdeploy::ModelFormat& model_format = fastdeploy::ModelFormat::PADDLE);
|
||||
const fastdeploy::ModelFormat& model_format = fastdeploy::ModelFormat::PADDLE,
|
||||
SchemaLanguage schema_language = SchemaLanguage::ZH);
|
||||
UIEModel(
|
||||
const std::string& model_file, const std::string& params_file,
|
||||
const std::string& vocab_file, float position_prob, size_t max_length,
|
||||
const std::vector<SchemaNode>& schema,
|
||||
const fastdeploy::RuntimeOption& custom_option =
|
||||
fastdeploy::RuntimeOption(),
|
||||
const fastdeploy::ModelFormat& model_format = fastdeploy::ModelFormat::PADDLE);
|
||||
const fastdeploy::ModelFormat& model_format =
|
||||
fastdeploy::ModelFormat::PADDLE,
|
||||
SchemaLanguage schema_language = SchemaLanguage::ZH);
|
||||
```
|
||||
|
||||
UIE模型加载和初始化,其中model_file, params_file为训练模型导出的Paddle inference文件,具体请参考其文档说明[模型导出](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/model_zoo/uie/README.md#%E6%A8%A1%E5%9E%8B%E9%83%A8%E7%BD%B2)。
|
||||
@@ -439,6 +443,7 @@ UIE模型加载和初始化,其中model_file, params_file为训练模型导出
|
||||
> * **schema**(list(SchemaNode) | SchemaNode | list(str)): 抽取任务的目标模式。
|
||||
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
|
||||
> * **model_format**(ModelFormat): 模型格式,默认为Paddle格式
|
||||
> * **schema_language** (SchemaLanguage): Schema 语言,默认为ZH(中文),目前支持的语言种类包括:ZH(中文),EN(英文)。
|
||||
|
||||
#### SetSchema函数
|
||||
|
||||
|
@@ -329,7 +329,9 @@ fd.text.uie.UIEModel(model_file,
|
||||
position_prob=0.5,
|
||||
max_length=128,
|
||||
schema=[],
|
||||
runtime_option=None,model_format=ModelFormat.PADDLE)
|
||||
runtime_option=None,
|
||||
model_format=ModelFormat.PADDLE,
|
||||
schema_language=SchemaLanguage.ZH)
|
||||
```
|
||||
|
||||
UIEModel模型加载和初始化,其中`model_file`, `params_file`为训练模型导出的Paddle inference文件,具体请参考其文档说明[模型导出](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/model_zoo/uie/README.md#%E6%A8%A1%E5%9E%8B%E9%83%A8%E7%BD%B2),`vocab_file`为词表文件,UIE模型的词表可在[UIE配置文件](https://github.com/PaddlePaddle/PaddleNLP/blob/5401f01af85f1c73d8017c6b3476242fce1e6d52/model_zoo/uie/utils.py)中下载相应的UIE模型的vocab_file。
|
||||
@@ -344,6 +346,7 @@ UIEModel模型加载和初始化,其中`model_file`, `params_file`为训练模
|
||||
> * **schema**(list|dict): 抽取任务的目标信息。
|
||||
> * **runtime_option**(RuntimeOption): 后端推理配置,默认为None,即采用默认配置
|
||||
> * **model_format**(ModelFormat): 模型格式,默认为Paddle格式
|
||||
> * **schema_language**(SchemaLanguage): Schema语言。默认为ZH(中文),目前支持的语言种类包括:ZH(中文),EN(英文)。
|
||||
|
||||
### set_schema函数
|
||||
|
||||
|
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import fastdeploy
|
||||
from fastdeploy.text import UIEModel
|
||||
from fastdeploy.text import UIEModel, SchemaLanguage
|
||||
import os
|
||||
from pprint import pprint
|
||||
|
||||
@@ -80,7 +80,8 @@ if __name__ == "__main__":
|
||||
position_prob=0.5,
|
||||
max_length=128,
|
||||
schema=schema,
|
||||
runtime_option=runtime_option)
|
||||
runtime_option=runtime_option,
|
||||
schema_language=SchemaLanguage.ZH)
|
||||
|
||||
print("1. Named Entity Recognition Task")
|
||||
print(f"The extraction schema: {schema}")
|
||||
|
@@ -165,14 +165,16 @@ UIEModel::UIEModel(const std::string& model_file,
|
||||
const std::string& vocab_file, float position_prob,
|
||||
size_t max_length, const std::vector<std::string>& schema,
|
||||
const fastdeploy::RuntimeOption& custom_option,
|
||||
const fastdeploy::ModelFormat& model_format)
|
||||
const fastdeploy::ModelFormat& model_format,
|
||||
SchemaLanguage schema_language)
|
||||
: max_length_(max_length),
|
||||
position_prob_(position_prob),
|
||||
schema_language_(schema_language),
|
||||
tokenizer_(vocab_file) {
|
||||
runtime_option_ = custom_option;
|
||||
runtime_option_.model_format = model_format;
|
||||
runtime_option_.SetModelPath(model_file, params_file);
|
||||
runtime_.Init(runtime_option_);
|
||||
runtime_option = custom_option;
|
||||
runtime_option.model_format = model_format;
|
||||
runtime_option.SetModelPath(model_file, params_file);
|
||||
initialized = Initialize();
|
||||
SetSchema(schema);
|
||||
tokenizer_.EnableTruncMethod(
|
||||
max_length, 0, faster_tokenizer::core::Direction::RIGHT,
|
||||
@@ -184,14 +186,16 @@ UIEModel::UIEModel(const std::string& model_file,
|
||||
const std::string& vocab_file, float position_prob,
|
||||
size_t max_length, const std::vector<SchemaNode>& schema,
|
||||
const fastdeploy::RuntimeOption& custom_option,
|
||||
const fastdeploy::ModelFormat& model_format)
|
||||
const fastdeploy::ModelFormat& model_format,
|
||||
SchemaLanguage schema_language)
|
||||
: max_length_(max_length),
|
||||
position_prob_(position_prob),
|
||||
schema_language_(schema_language),
|
||||
tokenizer_(vocab_file) {
|
||||
runtime_option_ = custom_option;
|
||||
runtime_option_.model_format = model_format;
|
||||
runtime_option_.SetModelPath(model_file, params_file);
|
||||
runtime_.Init(runtime_option_);
|
||||
runtime_option = custom_option;
|
||||
runtime_option.model_format = model_format;
|
||||
runtime_option.SetModelPath(model_file, params_file);
|
||||
initialized = Initialize();
|
||||
SetSchema(schema);
|
||||
tokenizer_.EnableTruncMethod(
|
||||
max_length, 0, faster_tokenizer::core::Direction::RIGHT,
|
||||
@@ -203,20 +207,33 @@ UIEModel::UIEModel(const std::string& model_file,
|
||||
const std::string& vocab_file, float position_prob,
|
||||
size_t max_length, const SchemaNode& schema,
|
||||
const fastdeploy::RuntimeOption& custom_option,
|
||||
const fastdeploy::ModelFormat& model_format)
|
||||
const fastdeploy::ModelFormat& model_format,
|
||||
SchemaLanguage schema_language)
|
||||
: max_length_(max_length),
|
||||
position_prob_(position_prob),
|
||||
schema_language_(schema_language),
|
||||
tokenizer_(vocab_file) {
|
||||
runtime_option_ = custom_option;
|
||||
runtime_option_.model_format = model_format;
|
||||
runtime_option_.SetModelPath(model_file, params_file);
|
||||
runtime_.Init(runtime_option_);
|
||||
runtime_option = custom_option;
|
||||
runtime_option.model_format = model_format;
|
||||
runtime_option.SetModelPath(model_file, params_file);
|
||||
initialized = Initialize();
|
||||
SetSchema(schema);
|
||||
tokenizer_.EnableTruncMethod(
|
||||
max_length, 0, faster_tokenizer::core::Direction::RIGHT,
|
||||
faster_tokenizer::core::TruncStrategy::LONGEST_FIRST);
|
||||
}
|
||||
|
||||
bool UIEModel::Initialize() {
|
||||
SetValidBackend();
|
||||
return InitRuntime();
|
||||
}
|
||||
|
||||
void UIEModel::SetValidBackend() {
|
||||
// TODO(zhoushunjie): Add lite backend in future
|
||||
valid_cpu_backends = {Backend::ORT, Backend::OPENVINO, Backend::PDINFER};
|
||||
valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
|
||||
}
|
||||
|
||||
void UIEModel::SetSchema(const std::vector<std::string>& schema) {
|
||||
schema_ = fastdeploy::utils::make_unique<Schema>(schema);
|
||||
}
|
||||
@@ -463,7 +480,7 @@ void UIEModel::AutoJoiner(const std::vector<std::string>& short_texts,
|
||||
*results = std::move(final_result);
|
||||
}
|
||||
|
||||
void UIEModel::ConstructTextsAndPrompts(
|
||||
bool UIEModel::ConstructTextsAndPrompts(
|
||||
const std::vector<std::string>& raw_texts, const std::string& node_name,
|
||||
const std::vector<std::vector<std::string>> node_prefix,
|
||||
std::vector<std::string>* input_texts, std::vector<std::string>* prompts,
|
||||
@@ -496,6 +513,10 @@ void UIEModel::ConstructTextsAndPrompts(
|
||||
}
|
||||
}
|
||||
|
||||
if (prompts->size() == 0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Shortten the input texts and prompts
|
||||
auto max_prompt_iter = std::max_element(
|
||||
prompts->begin(), prompts->end(),
|
||||
@@ -506,7 +527,6 @@ void UIEModel::ConstructTextsAndPrompts(
|
||||
rhs.c_str(), rhs.length());
|
||||
return lhs_ulen < rhs_ulen;
|
||||
});
|
||||
|
||||
auto max_prompt_len = faster_tokenizer::utils::GetUnicodeLenFromUTF8(
|
||||
max_prompt_iter->c_str(), max_prompt_iter->length());
|
||||
auto max_predict_len = max_length_ - 3 - max_prompt_len;
|
||||
@@ -521,6 +541,7 @@ void UIEModel::ConstructTextsAndPrompts(
|
||||
}
|
||||
(*input_texts) = std::move(short_texts);
|
||||
(*prompts) = std::move(short_texts_prompts);
|
||||
return true;
|
||||
}
|
||||
|
||||
void UIEModel::Preprocess(
|
||||
@@ -542,10 +563,10 @@ void UIEModel::Preprocess(
|
||||
if (batch_size > 0) {
|
||||
seq_len = (*encodings)[0].GetIds().size();
|
||||
}
|
||||
inputs->resize(runtime_.NumInputs());
|
||||
for (int i = 0; i < runtime_.NumInputs(); ++i) {
|
||||
inputs->resize(NumInputsOfRuntime());
|
||||
for (int i = 0; i < NumInputsOfRuntime(); ++i) {
|
||||
(*inputs)[i].Allocate({batch_size, seq_len}, fastdeploy::FDDataType::INT64,
|
||||
runtime_.GetInputInfo(i).name);
|
||||
InputInfoOfRuntime(i).name);
|
||||
}
|
||||
|
||||
// 2.2 Set the value of data
|
||||
@@ -619,8 +640,13 @@ void UIEModel::ConstructChildPromptPrefix(
|
||||
auto&& input_mapping_item = input_mapping_with_raw_texts[i];
|
||||
for (auto&& idx : input_mapping_item) {
|
||||
for (int j = 0; j < results_list[idx].size(); ++j) {
|
||||
// Note(zhoushunjie): It's useful for Chinese model.
|
||||
auto prefix_str = results_list[idx][j].text_ + "\xe7\x9a\x84";
|
||||
std::string prefix_str;
|
||||
if (schema_language_ == SchemaLanguage::ZH) {
|
||||
// Note(zhoushunjie): It means "of" in Chinese.
|
||||
prefix_str = results_list[idx][j].text_ + "\xe7\x9a\x84";
|
||||
} else {
|
||||
prefix_str = " of " + results_list[idx][j].text_;
|
||||
}
|
||||
(*prefix)[i].push_back(prefix_str);
|
||||
}
|
||||
}
|
||||
@@ -677,7 +703,7 @@ void UIEModel::ConstructChildRelations(
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < curr_relations.size(); ++i) {
|
||||
for (int j = 0; j < new_relations[i].size(); ++j) {
|
||||
for (int j = 0; j < curr_relations[i].size(); ++j) {
|
||||
if (curr_relations[i][j]->relation_.count(node_name)) {
|
||||
auto& curr_relation = curr_relations[i][j]->relation_[node_name];
|
||||
for (auto&& curr_result_ref : curr_relation) {
|
||||
@@ -706,24 +732,27 @@ void UIEModel::Predict(
|
||||
std::vector<std::string> short_input_texts;
|
||||
std::vector<std::string> short_prompts;
|
||||
// 1. Construct texts and prompts from raw text
|
||||
ConstructTextsAndPrompts(
|
||||
bool has_prompt = ConstructTextsAndPrompts(
|
||||
texts, node.name_, node.prefix_, &short_input_texts, &short_prompts,
|
||||
&input_mapping_with_raw_texts, &input_mapping_with_short_text);
|
||||
|
||||
// 2. Convert texts and prompts to FDTensor
|
||||
std::vector<FDTensor> inputs;
|
||||
std::vector<faster_tokenizer::core::Encoding> encodings;
|
||||
Preprocess(short_input_texts, short_prompts, &encodings, &inputs);
|
||||
|
||||
// 3. Infer
|
||||
std::vector<fastdeploy::FDTensor> outputs(runtime_.NumOutputs());
|
||||
runtime_.Infer(inputs, &outputs);
|
||||
|
||||
// 4. Convert FDTensor to UIEResult
|
||||
std::vector<std::vector<UIEResult>> results_list;
|
||||
Postprocess(outputs, encodings, short_input_texts, short_prompts,
|
||||
input_mapping_with_short_text, &results_list);
|
||||
if (has_prompt) {
|
||||
// 2. Convert texts and prompts to FDTensor
|
||||
std::vector<FDTensor> inputs;
|
||||
std::vector<faster_tokenizer::core::Encoding> encodings;
|
||||
Preprocess(short_input_texts, short_prompts, &encodings, &inputs);
|
||||
|
||||
// 3. Infer
|
||||
std::vector<fastdeploy::FDTensor> outputs(NumOutputsOfRuntime());
|
||||
if (!Infer(inputs, &outputs)) {
|
||||
FDERROR << "Failed to inference while using model:" << ModelName()
|
||||
<< "." << std::endl;
|
||||
}
|
||||
|
||||
// 4. Convert FDTensor to UIEResult
|
||||
Postprocess(outputs, encodings, short_input_texts, short_prompts,
|
||||
input_mapping_with_short_text, &results_list);
|
||||
}
|
||||
// 5. Construct the new relation of the UIEResult
|
||||
std::vector<std::vector<UIEResult*>> relations;
|
||||
ConstructChildRelations(node.relations_, input_mapping_with_raw_texts,
|
||||
|
@@ -75,6 +75,11 @@ struct FASTDEPLOY_DECL SchemaNode {
|
||||
}
|
||||
};
|
||||
|
||||
enum SchemaLanguage {
|
||||
ZH, // Chinese
|
||||
EN // English
|
||||
};
|
||||
|
||||
struct Schema {
|
||||
explicit Schema(const std::string& schema, const std::string& name = "root");
|
||||
explicit Schema(const std::vector<std::string>& schema_list,
|
||||
@@ -89,7 +94,7 @@ struct Schema {
|
||||
friend class UIEModel;
|
||||
};
|
||||
|
||||
struct FASTDEPLOY_DECL UIEModel {
|
||||
struct FASTDEPLOY_DECL UIEModel : public FastDeployModel {
|
||||
public:
|
||||
UIEModel(const std::string& model_file, const std::string& params_file,
|
||||
const std::string& vocab_file, float position_prob,
|
||||
@@ -97,26 +102,30 @@ struct FASTDEPLOY_DECL UIEModel {
|
||||
const fastdeploy::RuntimeOption& custom_option =
|
||||
fastdeploy::RuntimeOption(),
|
||||
const fastdeploy::ModelFormat& model_format =
|
||||
fastdeploy::ModelFormat::PADDLE);
|
||||
fastdeploy::ModelFormat::PADDLE,
|
||||
SchemaLanguage schema_language = SchemaLanguage::ZH);
|
||||
UIEModel(const std::string& model_file, const std::string& params_file,
|
||||
const std::string& vocab_file, float position_prob,
|
||||
size_t max_length, const SchemaNode& schema,
|
||||
const fastdeploy::RuntimeOption& custom_option =
|
||||
fastdeploy::RuntimeOption(),
|
||||
const fastdeploy::ModelFormat& model_format =
|
||||
fastdeploy::ModelFormat::PADDLE);
|
||||
fastdeploy::ModelFormat::PADDLE,
|
||||
SchemaLanguage schema_language = SchemaLanguage::ZH);
|
||||
UIEModel(const std::string& model_file, const std::string& params_file,
|
||||
const std::string& vocab_file, float position_prob,
|
||||
size_t max_length, const std::vector<SchemaNode>& schema,
|
||||
const fastdeploy::RuntimeOption& custom_option =
|
||||
fastdeploy::RuntimeOption(),
|
||||
const fastdeploy::ModelFormat& model_format =
|
||||
fastdeploy::ModelFormat::PADDLE);
|
||||
fastdeploy::ModelFormat::PADDLE,
|
||||
SchemaLanguage schema_language = SchemaLanguage::ZH);
|
||||
virtual std::string ModelName() const { return "UIEModel"; }
|
||||
void SetSchema(const std::vector<std::string>& schema);
|
||||
void SetSchema(const std::vector<SchemaNode>& schema);
|
||||
void SetSchema(const SchemaNode& schema);
|
||||
|
||||
void ConstructTextsAndPrompts(
|
||||
bool ConstructTextsAndPrompts(
|
||||
const std::vector<std::string>& raw_texts, const std::string& node_name,
|
||||
const std::vector<std::vector<std::string>> node_prefix,
|
||||
std::vector<std::string>* input_texts, std::vector<std::string>* prompts,
|
||||
@@ -150,7 +159,7 @@ struct FASTDEPLOY_DECL UIEModel {
|
||||
std::vector<std::unordered_map<std::string, std::vector<UIEResult>>>*
|
||||
results);
|
||||
|
||||
private:
|
||||
protected:
|
||||
using IDX_PROB = std::pair<int64_t, float>;
|
||||
struct IdxProbCmp {
|
||||
bool operator()(const std::pair<IDX_PROB, IDX_PROB>& lhs,
|
||||
@@ -161,6 +170,8 @@ struct FASTDEPLOY_DECL UIEModel {
|
||||
faster_tokenizer::core::Offset offset_;
|
||||
bool is_prompt_;
|
||||
};
|
||||
void SetValidBackend();
|
||||
bool Initialize();
|
||||
void AutoSplitter(const std::vector<std::string>& texts, size_t max_length,
|
||||
std::vector<std::string>* short_texts,
|
||||
std::vector<std::vector<size_t>>* input_mapping);
|
||||
@@ -185,11 +196,10 @@ struct FASTDEPLOY_DECL UIEModel {
|
||||
const std::vector<std::vector<SpanIdx>>& span_idxs,
|
||||
const std::vector<std::vector<float>>& probs,
|
||||
std::vector<std::vector<UIEResult>>* results) const;
|
||||
fastdeploy::RuntimeOption runtime_option_;
|
||||
fastdeploy::Runtime runtime_;
|
||||
std::unique_ptr<Schema> schema_;
|
||||
size_t max_length_;
|
||||
float position_prob_;
|
||||
SchemaLanguage schema_language_;
|
||||
faster_tokenizer::tokenizers_impl::ErnieFasterTokenizer tokenizer_;
|
||||
};
|
||||
|
||||
|
@@ -28,26 +28,34 @@ void BindUIE(pybind11::module& m) {
|
||||
.def_readwrite("relations", &text::SchemaNode::relations_)
|
||||
.def_readwrite("children", &text::SchemaNode::children_);
|
||||
|
||||
py::class_<text::UIEModel>(m, "UIEModel")
|
||||
py::enum_<text::SchemaLanguage>(m, "SchemaLanguage", py::arithmetic(),
|
||||
"The language of schema.")
|
||||
.value("ZH", text::SchemaLanguage::ZH)
|
||||
.value("EN", text::SchemaLanguage::EN);
|
||||
|
||||
py::class_<text::UIEModel, FastDeployModel>(m, "UIEModel")
|
||||
.def(py::init<std::string, std::string, std::string, float, size_t,
|
||||
std::vector<std::string>, RuntimeOption, ModelFormat>(),
|
||||
std::vector<std::string>, RuntimeOption, ModelFormat, text::SchemaLanguage>(),
|
||||
py::arg("model_file"), py::arg("params_file"), py::arg("vocab_file"),
|
||||
py::arg("position_prob"), py::arg("max_length"), py::arg("schema"),
|
||||
py::arg("custom_option") = fastdeploy::RuntimeOption(),
|
||||
py::arg("model_format") = fastdeploy::ModelFormat::PADDLE)
|
||||
py::arg("model_format") = fastdeploy::ModelFormat::PADDLE,
|
||||
py::arg("schema_language") = text::SchemaLanguage::ZH)
|
||||
.def(
|
||||
py::init<std::string, std::string, std::string, float, size_t,
|
||||
std::vector<text::SchemaNode>, RuntimeOption, ModelFormat>(),
|
||||
std::vector<text::SchemaNode>, RuntimeOption, ModelFormat, text::SchemaLanguage>(),
|
||||
py::arg("model_file"), py::arg("params_file"), py::arg("vocab_file"),
|
||||
py::arg("position_prob"), py::arg("max_length"), py::arg("schema"),
|
||||
py::arg("custom_option") = fastdeploy::RuntimeOption(),
|
||||
py::arg("model_format") = fastdeploy::ModelFormat::PADDLE)
|
||||
py::arg("model_format") = fastdeploy::ModelFormat::PADDLE,
|
||||
py::arg("schema_language") = text::SchemaLanguage::ZH)
|
||||
.def(py::init<std::string, std::string, std::string, float, size_t,
|
||||
text::SchemaNode, RuntimeOption, ModelFormat>(),
|
||||
text::SchemaNode, RuntimeOption, ModelFormat, text::SchemaLanguage>(),
|
||||
py::arg("model_file"), py::arg("params_file"), py::arg("vocab_file"),
|
||||
py::arg("position_prob"), py::arg("max_length"), py::arg("schema"),
|
||||
py::arg("custom_option") = fastdeploy::RuntimeOption(),
|
||||
py::arg("model_format") = fastdeploy::ModelFormat::PADDLE)
|
||||
py::arg("model_format") = fastdeploy::ModelFormat::PADDLE,
|
||||
py::arg("schema_language") = text::SchemaLanguage::ZH)
|
||||
.def("set_schema",
|
||||
static_cast<void (text::UIEModel::*)(
|
||||
const std::vector<std::string>&)>(&text::UIEModel::SetSchema),
|
||||
|
@@ -15,3 +15,4 @@ from __future__ import absolute_import
|
||||
|
||||
from . import uie
|
||||
from .uie import UIEModel
|
||||
from .uie import SchemaLanguage
|
||||
|
@@ -15,11 +15,14 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
import logging
|
||||
from ... import ModelFormat
|
||||
from ... import RuntimeOption
|
||||
from ... import RuntimeOption, FastDeployModel, ModelFormat
|
||||
from ... import c_lib_wrap as C
|
||||
|
||||
|
||||
class SchemaLanguage(C.text.SchemaLanguage):
|
||||
pass
|
||||
|
||||
|
||||
class SchemaNode(object):
|
||||
def __init__(self, name, children=[]):
|
||||
schema_node_children = []
|
||||
@@ -38,7 +41,7 @@ class SchemaNode(object):
|
||||
self._schema_node_children = schema_node_children
|
||||
|
||||
|
||||
class UIEModel(object):
|
||||
class UIEModel(FastDeployModel):
|
||||
def __init__(self,
|
||||
model_file,
|
||||
params_file,
|
||||
@@ -47,7 +50,8 @@ class UIEModel(object):
|
||||
max_length=128,
|
||||
schema=[],
|
||||
runtime_option=RuntimeOption(),
|
||||
model_format=ModelFormat.PADDLE):
|
||||
model_format=ModelFormat.PADDLE,
|
||||
schema_language=SchemaLanguage.ZH):
|
||||
if isinstance(schema, list):
|
||||
schema = SchemaNode("", schema)._schema_node_children
|
||||
elif isinstance(schema, dict):
|
||||
@@ -57,9 +61,10 @@ class UIEModel(object):
|
||||
schema = schema_tmp
|
||||
else:
|
||||
assert "The type of schema should be list or dict."
|
||||
self._model = C.text.UIEModel(model_file, params_file, vocab_file,
|
||||
position_prob, max_length, schema,
|
||||
runtime_option._option, model_format)
|
||||
self._model = C.text.UIEModel(
|
||||
model_file, params_file, vocab_file, position_prob, max_length,
|
||||
schema, runtime_option._option, model_format, schema_language)
|
||||
assert self.initialized, "UIEModel initialize failed."
|
||||
|
||||
def set_schema(self, schema):
|
||||
if isinstance(schema, list):
|
||||
|
Reference in New Issue
Block a user