Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[WIP] Add VI-LayoutXLM (#2048)

* WIP, add VI-LayoutXLM
* fix pybind
* update the dir of ser_vi_layoutxlm model
* update dir and name of ser_vi_layoutxlm model
* update model name to StructureV2SerViLayoutXLMModel
* fix import paddle bug

---------

Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com>
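For context, a minimal usage sketch of the model class this diff adds. The import location (fastdeploy.vision.ocr), file paths, dictionary file and backend choice below are assumptions and placeholders, not part of the commit; the preprocessor also requires paddleocr to be installed, since it builds a PaddleOCR engine internally.

    import cv2
    import fastdeploy as fd

    # The class asserts that an explicit backend is configured on the RuntimeOption.
    option = fd.RuntimeOption()
    option.use_ort_backend()  # any supported backend; ORT is just one choice

    # Placeholder model and dictionary paths.
    model = fd.vision.ocr.StructureV2SERViLayoutXLMModel(
        model_file="ser_vi_layoutxlm/model.pdmodel",
        params_file="ser_vi_layoutxlm/model.pdiparams",
        ser_dict_path="class_list_xfun.txt",
        class_path="class_list_xfun.txt",
        runtime_option=option)

    im = cv2.imread("doc_image.png")  # predict() expects a BGR numpy.ndarray
    result = model.predict(im)
    print(result)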
@@ -18,6 +18,11 @@ from .... import FastDeployModel, ModelFormat
from .... import c_lib_wrap as C
from ...common import ProcessorManager

from .utils.ser_vi_layoutxlm.vqa_utils import *
from .utils.ser_vi_layoutxlm.transforms import *
from .utils.ser_vi_layoutxlm.operators import *
from pathlib import Path


def sort_boxes(boxes):
    return C.vision.ocr.sort_boxes(boxes)
@@ -848,6 +853,7 @@ class StructureV2Layout(FastDeployModel):
    def postprocessor(self, value):
        self._model.postprocessor = value


class PPOCRv4(FastDeployModel):
    def __init__(self, det_model=None, cls_model=None, rec_model=None):
        """Construct a pipeline with text detector, direction classifier and text recognizer models
@@ -912,6 +918,7 @@ class PPOCRv4(FastDeployModel):
            int), "The value to set `rec_batch_size` must be type of int."
        self.system_.rec_batch_size = value


class PPOCRSystemv4(PPOCRv4):
    def __init__(self, det_model=None, cls_model=None, rec_model=None):
        logging.warning(
@@ -922,6 +929,7 @@ class PPOCRSystemv4(PPOCRv4):
    def predict(self, input_image):
        return super(PPOCRSystemv4, self).predict(input_image)


class PPOCRv3(FastDeployModel):
    def __init__(self, det_model=None, cls_model=None, rec_model=None):
        """Construct a pipeline with text detector, direction classifier and text recognizer models
@@ -1129,3 +1137,201 @@ class PPStructureV2TableSystem(PPStructureV2Table):

    def predict(self, input_image):
        return super(PPStructureV2TableSystem, self).predict(input_image)


class StructureV2SERViLayoutXLMModelPreprocessor():
    def __init__(self, ser_dict_path, use_gpu=True):
        """Create a preprocessor for the Ser-Vi-LayoutXLM model.
        :param ser_dict_path: (str) class file path
        :param use_gpu: (bool) whether to use the GPU for the OCR step
        """
        self._manager = None
        from paddleocr import PaddleOCR
        self.ocr_engine = PaddleOCR(
            use_angle_cls=False,
            det_model_dir=None,
            rec_model_dir=None,
            show_log=False,
            use_gpu=use_gpu)

        pre_process_list = [{
            'VQATokenLabelEncode': {
                'class_path': ser_dict_path,
                'contains_re': False,
                'ocr_engine': self.ocr_engine,
                'order_method': "tb-yx"
            }
        }, {
            'VQATokenPad': {
                'max_seq_len': 512,
                'return_attention_mask': True
            }
        }, {
            'VQASerTokenChunk': {
                'max_seq_len': 512,
                'return_attention_mask': True
            }
        }, {
            'Resize': {
                'size': [224, 224]
            }
        }, {
            'NormalizeImage': {
                'std': [58.395, 57.12, 57.375],
                'mean': [123.675, 116.28, 103.53],
                'scale': '1',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            'KeepKeys': {
                'keep_keys': [
                    'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
                    'image', 'labels', 'segment_offset_id', 'ocr_info',
                    'entities'
                ]
            }
        }]

        self.preprocess_op = create_operators(pre_process_list,
                                              {'infer_mode': True})

    def _transform(self, data, ops=None):
        """ transform """
        if ops is None:
            ops = []
        for op in ops:
            data = op(data)
            if data is None:
                return None
        return data

    def run(self, input_im):
        """Run preprocessing for the Ser-Vi-LayoutXLM model.
        :param input_im: (numpy.ndarray) input image
        """
        ori_im = input_im.copy()
        data = {'image': input_im}
        data = transform(data, self.preprocess_op)

        for idx in range(len(data)):
            if isinstance(data[idx], np.ndarray):
                data[idx] = np.expand_dims(data[idx], axis=0)
            else:
                data[idx] = [data[idx]]

        # The returned list follows the KeepKeys order above: input_ids, bbox,
        # attention_mask, token_type_ids, image, labels, segment_offset_id,
        # ocr_info, entities.
        return data


class StructureV2SERViLayoutXLMModelPostprocessor():
    def __init__(self, class_path):
        """Create a postprocessor for the Ser-Vi-LayoutXLM model.
        :param class_path: (str) class file path
        """
        self.postprocessor_op = VQASerTokenLayoutLMPostProcess(class_path)

    def run(self, preds, batch=None, *args, **kwargs):
        """Run postprocessing for the Ser-Vi-LayoutXLM model.
        :param preds: (list) inference results
        """
        return self.postprocessor_op(preds, batch, *args, **kwargs)


class StructureV2SERViLayoutXLMModel(FastDeployModel):
    def __init__(self,
                 model_file,
                 params_file,
                 ser_dict_path,
                 class_path,
                 config_file="",
                 runtime_option=None,
                 model_format=ModelFormat.PADDLE):
        """Load the SERViLayoutXLM model provided by PP-StructureV2.

        :param model_file: (str) Path of the model file, e.g. ./ser_vi_layout_xlm/model.pdmodel.
        :param params_file: (str) Path of the parameter file, e.g. ./ser_vi_layout_xlm/model.pdiparams; if the model format is ONNX, this parameter will be ignored.
        :param ser_dict_path: (str) class file path
        :param class_path: (str) class file path
        :param runtime_option: (fastdeploy.RuntimeOption) RuntimeOption for inference of this model; if it's None, the default backend on CPU will be used.
        :param model_format: (fastdeploy.ModelFormat) Model format of the loaded model.
        """
        super(StructureV2SERViLayoutXLMModel, self).__init__(runtime_option)

        assert self._runtime_option.backend != 0, \
            "RuntimeOption requires an explicit backend setting."
        self._model = C.vision.ocr.StructureV2SERViLayoutXLMModel(
            model_file, params_file, config_file, self._runtime_option,
            model_format)

        assert self.initialized, "SERViLayoutXLM model initialization failed."

        self.preprocessor = StructureV2SERViLayoutXLMModelPreprocessor(
            ser_dict_path)
        self.postprocesser = StructureV2SERViLayoutXLMModelPostprocessor(
            class_path)

        self.input_name_0 = self._model.get_input_info(0).name
        self.input_name_1 = self._model.get_input_info(1).name
        self.input_name_2 = self._model.get_input_info(2).name
        self.input_name_3 = self._model.get_input_info(3).name

    def predict(self, image):
        assert isinstance(image,
                          np.ndarray), "predict receives a numpy.ndarray (BGR)"

        data = self.preprocessor.run(image)
        infer_input = {
            self.input_name_0: data[0],
            self.input_name_1: data[1],
            self.input_name_2: data[2],
            self.input_name_3: data[3],
        }

        infer_result = self._model.infer(infer_input)
        infer_result = infer_result[0]

        post_result = self.postprocesser.run(infer_result,
                                              segment_offset_ids=data[6],
                                              ocr_infos=data[7])

        return post_result

    def batch_predict(self, image_list):
        assert isinstance(image_list, list) and \
            isinstance(image_list[0], np.ndarray), \
            "batch_predict receives a list of numpy.ndarray (BGR)"

        # read and preprocess images
        datas = None
        for image in image_list:
            data = self.preprocessor.run(image)

            # concatenate the per-image data into a batch
            if datas is None:
                datas = data
            else:
                for idx in range(len(data)):
                    if isinstance(data[idx], np.ndarray):
                        datas[idx] = np.concatenate(
                            (datas[idx], data[idx]), axis=0)
                    else:
                        datas[idx].extend(data[idx])

        # infer
        infer_inputs = {
            self.input_name_0: datas[0],
            self.input_name_1: datas[1],
            self.input_name_2: datas[2],
            self.input_name_3: datas[3],
        }

        infer_results = self._model.infer(infer_inputs)
        infer_results = infer_results[0]

        # postprocess
        post_results = self.postprocesser.run(infer_results,
                                               segment_offset_ids=datas[6],
                                               ocr_infos=datas[7])

        return post_results
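A corresponding batch call, under the same placeholder assumptions as the sketch near the top of this page: batch_predict concatenates each image's padded arrays along axis 0 (VQATokenPad pads every sequence to max_seq_len=512, so the per-image shapes match) and then runs a single infer.

    images = [cv2.imread(p) for p in ("page_0.png", "page_1.png")]  # placeholder paths
    batch_results = model.batch_predict(images)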