[WIP]Add VI-LayoutXLM (#2048)

* WIP, add VI-LayoutXLM

* fix pybind

* update the dir of ser_vi_layoutxlm model

* update dir and name of ser_vi_layoutxlm model

* update model name to StructureV2SerViLayoutXLMModel

* fix import paddle bug

---------

Co-authored-by: DefTruth <31974251+DefTruth@users.noreply.github.com>
This commit is contained in:
zengshao0622
2023-06-26 16:40:05 +08:00
committed by GitHub
parent 90e4fccbf8
commit 709ba51612
10 changed files with 1092 additions and 0 deletions

View File

@@ -18,6 +18,11 @@ from .... import FastDeployModel, ModelFormat
from .... import c_lib_wrap as C
from ...common import ProcessorManager
from .utils.ser_vi_layoutxlm.vqa_utils import *
from .utils.ser_vi_layoutxlm.transforms import *
from .utils.ser_vi_layoutxlm.operators import *
from pathlib import Path
def sort_boxes(boxes):
    """Sort OCR boxes by delegating to the native ``C.vision.ocr.sort_boxes``.

    :param boxes: boxes to sort; forwarded unchanged to the native binding
    :return: whatever the native ``sort_boxes`` binding returns for ``boxes``
    """
    sorted_boxes = C.vision.ocr.sort_boxes(boxes)
    return sorted_boxes
@@ -848,6 +853,7 @@ class StructureV2Layout(FastDeployModel):
def postprocessor(self, value):
self._model.postprocessor = value
class PPOCRv4(FastDeployModel):
def __init__(self, det_model=None, cls_model=None, rec_model=None):
"""Consruct a pipeline with text detector, direction classifier and text recognizer models
@@ -912,6 +918,7 @@ class PPOCRv4(FastDeployModel):
int), "The value to set `rec_batch_size` must be type of int."
self.system_.rec_batch_size = value
class PPOCRSystemv4(PPOCRv4):
def __init__(self, det_model=None, cls_model=None, rec_model=None):
logging.warning(
@@ -922,6 +929,7 @@ class PPOCRSystemv4(PPOCRv4):
def predict(self, input_image):
    """Run prediction through the parent-class implementation.

    :param input_image: input forwarded unchanged to the parent ``predict``
    :return: the parent ``predict`` result for ``input_image``
    """
    parent_predict = super(PPOCRSystemv4, self).predict
    return parent_predict(input_image)
class PPOCRv3(FastDeployModel):
def __init__(self, det_model=None, cls_model=None, rec_model=None):
"""Consruct a pipeline with text detector, direction classifier and text recognizer models
@@ -1129,3 +1137,201 @@ class PPStructureV2TableSystem(PPStructureV2Table):
def predict(self, input_image):
    """Run prediction through the parent-class implementation.

    :param input_image: input forwarded unchanged to the parent ``predict``
    :return: the parent ``predict`` result for ``input_image``
    """
    parent_predict = super(PPStructureV2TableSystem, self).predict
    return parent_predict(input_image)
class StructureV2SERViLayoutXLMModelPreprocessor():
    """Preprocessing pipeline for the Ser-Vi-LayoutXLM model.

    Builds a PaddleOCR engine plus a fixed list of preprocess operators and
    turns one input image into a list of batched (leading axis of size 1)
    fields.
    """

    def __init__(self, ser_dict_path, use_gpu=True):
        """Create a preprocessor for Ser-Vi-LayoutXLM model.

        :param ser_dict_path: (str) class file path, handed to the
            ``VQATokenLabelEncode`` operator as ``class_path``
        :param use_gpu: (bool) whether the PaddleOCR engine uses the GPU
        """
        self._manager = None
        # Imported here rather than at module level, so importing this module
        # does not itself require paddleocr to be installed.
        from paddleocr import PaddleOCR
        self.ocr_engine = PaddleOCR(
            use_angle_cls=False,
            det_model_dir=None,
            rec_model_dir=None,
            show_log=False,
            use_gpu=use_gpu)
        pre_process_list = [{
            'VQATokenLabelEncode': {
                'class_path': ser_dict_path,
                'contains_re': False,
                'ocr_engine': self.ocr_engine,
                'order_method': "tb-yx"
            }
        }, {
            'VQATokenPad': {
                'max_seq_len': 512,
                'return_attention_mask': True
            }
        }, {
            'VQASerTokenChunk': {
                'max_seq_len': 512,
                'return_attention_mask': True
            }
        }, {
            'Resize': {
                'size': [224, 224]
            }
        }, {
            'NormalizeImage': {
                'std': [58.395, 57.12, 57.375],
                'mean': [123.675, 116.28, 103.53],
                'scale': '1',
                'order': 'hwc'
            }
        }, {
            'ToCHWImage': None
        }, {
            # NOTE(review): presumably KeepKeys emits the fields as a list in
            # exactly this order; StructureV2SERViLayoutXLMModel indexes the
            # result positionally (0-3, 6 and 7) -- confirm before reordering.
            'KeepKeys': {
                'keep_keys': [
                    'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
                    'image', 'labels', 'segment_offset_id', 'ocr_info',
                    'entities'
                ]
            }
        }]
        self.preprocess_op = create_operators(pre_process_list,
                                              {'infer_mode': True})

    def _transform(self, data, ops=None):
        """Apply ``ops`` to ``data`` in order.

        :param data: value passed to the first operator
        :param ops: (list) callables applied one after another; ``None``
            means no operators
        :return: the output of the last operator, or ``None`` as soon as any
            operator returns ``None``
        """
        if ops is None:
            ops = []
        for op in ops:
            data = op(data)
            if data is None:
                return None
        return data

    def run(self, input_im):
        """Run preprocess of Ser-Vi-LayoutXLM model.

        :param input_im: (numpy.ndarray) input image
        :return: (list) one entry per field produced by the operator
            pipeline; numpy arrays gain a leading axis of size 1, every
            other value is wrapped in a single-element list
        :raises ValueError: if an operator returned ``None`` for this image
        """
        data = self._transform({'image': input_im}, self.preprocess_op)
        if data is None:
            # Fail with a clear message instead of a TypeError from len(None).
            raise ValueError(
                "Ser-Vi-LayoutXLM preprocessing returned None: one of the "
                "preprocess operators rejected the input image.")
        return [
            np.expand_dims(field, axis=0)
            if isinstance(field, np.ndarray) else [field] for field in data
        ]
class StructureV2SERViLayoutXLMModelPostprocessor():
    """Postprocessing wrapper for the Ser-Vi-LayoutXLM model."""

    def __init__(self, class_path):
        """Build the postprocessor from a class file.

        :param class_path: (str) class file path, passed to
            ``VQASerTokenLayoutLMPostProcess``
        """
        decoder = VQASerTokenLayoutLMPostProcess(class_path)
        self.postprocessor_op = decoder

    def run(self, preds, batch=None, *args, **kwargs):
        """Decode inference results with the wrapped postprocess operator.

        :param preds: (list) inference results
        :param batch: optional batch data, forwarded as the second argument
        :return: the result of calling the postprocess operator with
            ``preds``, ``batch`` and any extra positional/keyword arguments
        """
        decode = self.postprocessor_op
        decoded = decode(preds, batch, *args, **kwargs)
        return decoded
class StructureV2SERViLayoutXLMModel(FastDeployModel):
    """SER-VI-LayoutXLM model: Python-side pre/post-processing around the
    native ``C.vision.ocr.StructureV2SERViLayoutXLMModel`` runtime."""

    def __init__(self,
                 model_file,
                 params_file,
                 ser_dict_path,
                 class_path,
                 config_file="",
                 runtime_option=None,
                 model_format=ModelFormat.PADDLE):
        """Load SERViLayoutXLM model provided by PP-StructureV2.

        :param model_file: (str)Path of model file, e.g ./ser_vi_layout_xlm/model.pdmodel.
        :param params_file: (str)Path of parameter file, e.g ./ser_vi_layout_xlm/model.pdiparams, if the model format is ONNX, this parameter will be ignored.
        :param ser_dict_path: (str) class file path used by the preprocessor
        :param class_path: (str) class file path used by the postprocessor
        :param config_file: (str) config file path forwarded to the native model, empty by default
        :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU.
        :param model_format: (fastdeploy.ModelFormat)Model format of the loaded model.
        """
        super(StructureV2SERViLayoutXLMModel, self).__init__(runtime_option)

        assert self._runtime_option.backend != 0, \
            "Runtime Option required backend setting."
        self._model = C.vision.ocr.StructureV2SERViLayoutXLMModel(
            model_file, params_file, config_file, self._runtime_option,
            model_format)

        assert self.initialized, "SERViLayoutXLM model initialize failed."

        self.preprocessor = StructureV2SERViLayoutXLMModelPreprocessor(
            ser_dict_path)
        # NOTE: the attribute name `postprocesser` (sic) is kept unchanged so
        # existing code that reads it keeps working.
        self.postprocesser = StructureV2SERViLayoutXLMModelPostprocessor(
            class_path)

        # Names of the first four runtime inputs; they are paired positionally
        # with the first four preprocessed fields when building the feed dict.
        self.input_name_0 = self._model.get_input_info(0).name
        self.input_name_1 = self._model.get_input_info(1).name
        self.input_name_2 = self._model.get_input_info(2).name
        self.input_name_3 = self._model.get_input_info(3).name

    def _infer_and_postprocess(self, data):
        """Run inference on preprocessed fields and decode the first output.

        :param data: (list) preprocessed fields; entries 0-3 feed the model
            inputs, entries 6 and 7 are passed to the postprocessor as
            ``segment_offset_ids`` and ``ocr_infos``
        :return: the postprocessor result
        """
        infer_input = {
            self.input_name_0: data[0],
            self.input_name_1: data[1],
            self.input_name_2: data[2],
            self.input_name_3: data[3],
        }
        # Only the first runtime output is decoded.
        infer_result = self._model.infer(infer_input)[0]
        return self.postprocesser.run(
            infer_result, segment_offset_ids=data[6], ocr_infos=data[7])

    def predict(self, image):
        """Predict on a single image.

        :param image: (numpy.ndarray) input image (BGR)
        :return: the postprocessor result for this image
        """
        assert isinstance(image,
                          np.ndarray), "predict recives numpy.ndarray(BGR)"

        data = self.preprocessor.run(image)
        return self._infer_and_postprocess(data)

    def batch_predict(self, image_list):
        """Predict on a list of images in one inference call.

        :param image_list: (list of numpy.ndarray) non-empty list of input
            images (BGR)
        :return: the postprocessor result for the whole batch
        """
        # Reject an empty list here (it used to surface as an IndexError)
        # and validate every element, not just the first one.
        assert isinstance(image_list, list) and len(image_list) > 0 and \
            all(isinstance(image, np.ndarray) for image in image_list), \
            "batch_predict recives list of numpy.ndarray(BGR)"

        # Preprocess each image and merge the per-image fields into a batch.
        datas = None
        for image in image_list:
            data = self.preprocessor.run(image)
            if datas is None:
                datas = data
                continue
            for idx, field in enumerate(data):
                if isinstance(field, np.ndarray):
                    # Arrays are stacked along their leading axis.
                    datas[idx] = np.concatenate((datas[idx], field), axis=0)
                else:
                    # Non-array fields are single-element lists.
                    datas[idx].extend(field)

        return self._infer_and_postprocess(datas)