mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Model] Upgrade uie (#458)
* Upgrade uie c++ implement * upgrade python UIEModel inherit FastDeployModel * Add schema language support; Skip infer when no prompts * Adjust the schema language arg pos * Add schema_language for python and cpp * update pybind for uie * Fix the args of uie * Add SchemaLanguage
This commit is contained in:
@@ -15,3 +15,4 @@ from __future__ import absolute_import
|
||||
|
||||
from . import uie
|
||||
from .uie import UIEModel
|
||||
from .uie import SchemaLanguage
|
||||
|
@@ -15,11 +15,14 @@
|
||||
from __future__ import absolute_import
|
||||
|
||||
import logging
|
||||
from ... import ModelFormat
|
||||
from ... import RuntimeOption
|
||||
from ... import RuntimeOption, FastDeployModel, ModelFormat
|
||||
from ... import c_lib_wrap as C
|
||||
|
||||
|
||||
class SchemaLanguage(C.text.SchemaLanguage):
|
||||
pass
|
||||
|
||||
|
||||
class SchemaNode(object):
|
||||
def __init__(self, name, children=[]):
|
||||
schema_node_children = []
|
||||
@@ -38,7 +41,7 @@ class SchemaNode(object):
|
||||
self._schema_node_children = schema_node_children
|
||||
|
||||
|
||||
class UIEModel(object):
|
||||
class UIEModel(FastDeployModel):
|
||||
def __init__(self,
|
||||
model_file,
|
||||
params_file,
|
||||
@@ -47,7 +50,8 @@ class UIEModel(object):
|
||||
max_length=128,
|
||||
schema=[],
|
||||
runtime_option=RuntimeOption(),
|
||||
model_format=ModelFormat.PADDLE):
|
||||
model_format=ModelFormat.PADDLE,
|
||||
schema_language=SchemaLanguage.ZH):
|
||||
if isinstance(schema, list):
|
||||
schema = SchemaNode("", schema)._schema_node_children
|
||||
elif isinstance(schema, dict):
|
||||
@@ -57,9 +61,10 @@ class UIEModel(object):
|
||||
schema = schema_tmp
|
||||
else:
|
||||
assert "The type of schema should be list or dict."
|
||||
self._model = C.text.UIEModel(model_file, params_file, vocab_file,
|
||||
position_prob, max_length, schema,
|
||||
runtime_option._option, model_format)
|
||||
self._model = C.text.UIEModel(
|
||||
model_file, params_file, vocab_file, position_prob, max_length,
|
||||
schema, runtime_option._option, model_format, schema_language)
|
||||
assert self.initialized, "UIEModel initialize failed."
|
||||
|
||||
def set_schema(self, schema):
|
||||
if isinstance(schema, list):
|
||||
|
Reference in New Issue
Block a user