diff --git a/fastdeploy/vision.h b/fastdeploy/vision.h
index 7256d2f5b..c0cb67103 100755
--- a/fastdeploy/vision.h
+++ b/fastdeploy/vision.h
@@ -58,6 +58,7 @@
 #include "fastdeploy/vision/ocr/ppocr/dbdetector.h"
 #include "fastdeploy/vision/ocr/ppocr/structurev2_table.h"
 #include "fastdeploy/vision/ocr/ppocr/structurev2_layout.h"
+#include "fastdeploy/vision/ocr/ppocr/structurev2_ser_vi_layoutxlm.h"
 #include "fastdeploy/vision/ocr/ppocr/ppocr_v2.h"
 #include "fastdeploy/vision/ocr/ppocr/ppocr_v3.h"
 #include "fastdeploy/vision/ocr/ppocr/ppocr_v4.h"
diff --git a/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc b/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc
index 243d93e26..b468a20d2 100644
--- a/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc
+++ b/fastdeploy/vision/ocr/ppocr/ocrmodel_pybind.cc
@@ -522,5 +522,63 @@ void BindPPOCRModel(pybind11::module& m) {
         self.BatchPredict(images, &results);
         return results;
       });
+
+  pybind11::class_<vision::ocr::StructureV2SERViLayoutXLMModel,
+                   FastDeployModel>(m, "StructureV2SERViLayoutXLMModel")
+      .def(pybind11::init<std::string, std::string, std::string, RuntimeOption,
+                          ModelFormat>())
+      .def("clone",
+           [](vision::ocr::StructureV2SERViLayoutXLMModel& self) {
+             return self.Clone();
+           })
+      .def("predict",
+           [](vision::ocr::StructureV2SERViLayoutXLMModel& self,
+              pybind11::array& data) {
+             throw std::runtime_error(
+                 "StructureV2SERViLayoutXLMModel does not support predict.");
+           })
+      .def("batch_predict",
+           [](vision::ocr::StructureV2SERViLayoutXLMModel& self,
+              std::vector<pybind11::array>& data) {
+             throw std::runtime_error(
+                 "StructureV2SERViLayoutXLMModel does not support "
+                 "batch_predict.");
+           })
+      .def("infer",
+           [](vision::ocr::StructureV2SERViLayoutXLMModel& self,
+              std::map<std::string, pybind11::array>& data) {
+             std::vector<FDTensor> inputs(data.size());
+             int index = 0;
+             for (auto iter = data.begin(); iter != data.end(); ++iter) {
+               std::vector<int64_t> data_shape;
+               data_shape.insert(data_shape.begin(), iter->second.shape(),
+                                 iter->second.shape() + iter->second.ndim());
+               auto dtype = NumpyDataTypeToFDDataType(iter->second.dtype());
+
+               inputs[index].Resize(data_shape, dtype);
+               memcpy(inputs[index].MutableData(), iter->second.mutable_data(),
+                      iter->second.nbytes());
+               inputs[index].name = iter->first;
+               index += 1;
+             }
+
+             std::vector<FDTensor> outputs(self.NumOutputsOfRuntime());
+             self.Infer(inputs, &outputs);
+
+             std::vector<pybind11::array> results;
+             results.reserve(outputs.size());
+             for (size_t i = 0; i < outputs.size(); ++i) {
+               auto numpy_dtype = FDDataTypeToNumpyDataType(outputs[i].dtype);
+               results.emplace_back(
+                   pybind11::array(numpy_dtype, outputs[i].shape));
+               memcpy(results[i].mutable_data(), outputs[i].Data(),
+                      outputs[i].Numel() * FDDataTypeSize(outputs[i].dtype));
+             }
+             return results;
+           })
+      .def("get_input_info",
+           [](vision::ocr::StructureV2SERViLayoutXLMModel& self, int& index) {
+             return self.InputInfoOfRuntime(index);
+           });
 }
 }  // namespace fastdeploy
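The binding above intentionally routes SER inference through `infer` with named tensors; `predict` and `batch_predict` raise. Since `std::map` iterates its keys in sorted order, the FDTensor vector is filled in sorted-name order, and correct binding relies on each tensor carrying its `name`. A hedged Python-side sketch of assembling the feed dict (the helper is illustrative, not part of this PR):

```python
import numpy as np

def build_feed_dict(model, arrays):
    """Map arrays (given in runtime input order) onto the runtime's
    input names via the bound get_input_info() accessor (sketch)."""
    return {model.get_input_info(i).name: np.ascontiguousarray(a)
            for i, a in enumerate(arrays)}
```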
diff --git a/fastdeploy/vision/ocr/ppocr/structurev2_ser_vi_layoutxlm.cc b/fastdeploy/vision/ocr/ppocr/structurev2_ser_vi_layoutxlm.cc
new file mode 100644
index 000000000..837c5d2c1
--- /dev/null
+++ b/fastdeploy/vision/ocr/ppocr/structurev2_ser_vi_layoutxlm.cc
@@ -0,0 +1,72 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "fastdeploy/vision/ocr/ppocr/structurev2_ser_vi_layoutxlm.h"
+
+#include "fastdeploy/utils/unique_ptr.h"
+
+namespace fastdeploy {
+namespace vision {
+namespace ocr {
+
+StructureV2SERViLayoutXLMModel::StructureV2SERViLayoutXLMModel(
+    const std::string& model_file, const std::string& params_file,
+    const std::string& config_file, const RuntimeOption& custom_option,
+    const ModelFormat& model_format) {
+  if (model_format == ModelFormat::PADDLE) {
+    valid_cpu_backends = {Backend::OPENVINO, Backend::PDINFER, Backend::ORT,
+                          Backend::LITE};
+    valid_gpu_backends = {Backend::ORT, Backend::PDINFER, Backend::TRT};
+    valid_timvx_backends = {Backend::LITE};
+    valid_ascend_backends = {Backend::LITE};
+    valid_kunlunxin_backends = {Backend::LITE};
+    valid_ipu_backends = {Backend::PDINFER};
+    valid_directml_backends = {Backend::ORT};
+  } else if (model_format == ModelFormat::SOPHGO) {
+    valid_sophgonpu_backends = {Backend::SOPHGOTPU};
+  } else {
+    valid_cpu_backends = {Backend::ORT, Backend::OPENVINO};
+    valid_gpu_backends = {Backend::ORT, Backend::TRT};
+    valid_rknpu_backends = {Backend::RKNPU2};
+    valid_directml_backends = {Backend::ORT};
+    valid_horizon_backends = {Backend::HORIZONNPU};
+  }
+
+  runtime_option = custom_option;
+  runtime_option.model_format = model_format;
+  runtime_option.model_file = model_file;
+  runtime_option.params_file = params_file;
+  initialized = Initialize();
+}
+
+std::unique_ptr<StructureV2SERViLayoutXLMModel>
+StructureV2SERViLayoutXLMModel::Clone() const {
+  std::unique_ptr<StructureV2SERViLayoutXLMModel> clone_model =
+      utils::make_unique<StructureV2SERViLayoutXLMModel>(
+          StructureV2SERViLayoutXLMModel(*this));
+  clone_model->SetRuntime(clone_model->CloneRuntime());
+  return clone_model;
+}
+
+bool StructureV2SERViLayoutXLMModel::Initialize() {
+  if (!InitRuntime()) {
+    FDERROR << "Failed to initialize fastdeploy backend." << std::endl;
+    return false;
+  }
+  return true;
+}
+
+}  // namespace ocr
+}  // namespace vision
+}  // namespace fastdeploy
\ No newline at end of file
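Clone() follows the pattern used by other FastDeploy models: the copy shares the loaded model buffers and receives a cloned runtime, so creating extra handles is much cheaper than re-initializing. A minimal sketch of fanning out an initialized model (the helper name is an assumption):

```python
def fan_out(model, n):
    """Create n lightweight handles of an already-initialized model."""
    return [model.clone() for _ in range(n)]
```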
diff --git a/fastdeploy/vision/ocr/ppocr/structurev2_ser_vi_layoutxlm.h b/fastdeploy/vision/ocr/ppocr/structurev2_ser_vi_layoutxlm.h
new file mode 100644
index 000000000..3da12b3c4
--- /dev/null
+++ b/fastdeploy/vision/ocr/ppocr/structurev2_ser_vi_layoutxlm.h
@@ -0,0 +1,60 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "fastdeploy/fastdeploy_model.h"
+#include "fastdeploy/vision/common/processors/transform.h"
+
+namespace fastdeploy {
+namespace vision {
+/** \brief All OCR series model APIs are defined inside this namespace
+ *
+ */
+namespace ocr {
+/*! @brief StructureV2SERViLayoutXLM model object, used to load a StructureV2SERViLayoutXLM model exported by the PP-StructureV2 repository
+ */
+class FASTDEPLOY_DECL StructureV2SERViLayoutXLMModel : public FastDeployModel {
+ public:
+  /** \brief Set path of model file and configuration file, and the configuration of runtime
+   *
+   * \param[in] model_file Path of model file, e.g ser_vi_layoutxlm/model.pdmodel
+   * \param[in] params_file Path of parameter file, e.g ser_vi_layoutxlm/model.pdiparams, if the model format is ONNX, this parameter will be ignored
+   * \param[in] config_file Path of configuration file for deployment, e.g ser_vi_layoutxlm/infer_cfg.yml
+   * \param[in] custom_option RuntimeOption for inference, the default will use cpu, and choose the backend defined in `valid_cpu_backends`
+   * \param[in] model_format Model format of the loaded model, default is Paddle format
+   */
+  StructureV2SERViLayoutXLMModel(
+      const std::string& model_file, const std::string& params_file,
+      const std::string& config_file,
+      const RuntimeOption& custom_option = RuntimeOption(),
+      const ModelFormat& model_format = ModelFormat::PADDLE);
+
+  /** \brief Clone a new StructureV2SERViLayoutXLMModel with less memory usage when multiple instances of the same model are created
+   *
+   * \return new StructureV2SERViLayoutXLMModel* type unique pointer
+   */
+  virtual std::unique_ptr<StructureV2SERViLayoutXLMModel> Clone() const;
+
+  /// Get model's name
+  virtual std::string ModelName() const {
+    return "StructureV2SERViLayoutXLMModel";
+  }
+
+ protected:
+  bool Initialize();
+};
+
+}  // namespace ocr
+}  // namespace vision
+}  // namespace fastdeploy
diff --git a/python/fastdeploy/vision/ocr/ppocr/__init__.py b/python/fastdeploy/vision/ocr/ppocr/__init__.py
index 4f04b3210..7cec60039 100755
--- a/python/fastdeploy/vision/ocr/ppocr/__init__.py
+++ b/python/fastdeploy/vision/ocr/ppocr/__init__.py
@@ -18,6 +18,11 @@ from .... import FastDeployModel, ModelFormat
 from .... import c_lib_wrap as C
 from ...common import ProcessorManager
+from .utils.ser_vi_layoutxlm.vqa_utils import *
+from .utils.ser_vi_layoutxlm.transforms import *
+from .utils.ser_vi_layoutxlm.operators import *
+from pathlib import Path
+
 
 def sort_boxes(boxes):
     return C.vision.ocr.sort_boxes(boxes)
@@ -848,6 +853,7 @@ class StructureV2Layout(FastDeployModel):
     def postprocessor(self, value):
         self._model.postprocessor = value
 
+
 class PPOCRv4(FastDeployModel):
     def __init__(self, det_model=None, cls_model=None, rec_model=None):
         """Construct a pipeline with text detector, direction classifier and text recognizer models
@@ -912,6 +918,7 @@ class PPOCRv4(FastDeployModel):
             int), "The value to set `rec_batch_size` must be type of int."
         self.system_.rec_batch_size = value
 
+
 class PPOCRSystemv4(PPOCRv4):
     def __init__(self, det_model=None, cls_model=None, rec_model=None):
         logging.warning(
@@ -922,6 +929,7 @@ class PPOCRSystemv4(PPOCRv4):
     def predict(self, input_image):
         return super(PPOCRSystemv4, self).predict(input_image)
 
+
 class PPOCRv3(FastDeployModel):
     def __init__(self, det_model=None, cls_model=None, rec_model=None):
         """Construct a pipeline with text detector, direction classifier and text recognizer models
@@ -1129,3 +1137,201 @@ class PPStructureV2TableSystem(PPStructureV2Table):
 
     def predict(self, input_image):
         return super(PPStructureV2TableSystem, self).predict(input_image)
+
+
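The preprocessor defined below drives all preprocessing through a config-built operator chain (`create_operators`/`transform` from the new utils package). A minimal standalone sketch of that contract, with a hypothetical op:

```python
def chain(data, ops):
    # Each op maps a dict to a dict; returning None aborts the pipeline.
    for op in ops:
        data = op(data)
        if data is None:
            return None
    return data

class AddOne:
    def __call__(self, data):
        data["x"] += 1
        return data

assert chain({"x": 0}, [AddOne(), AddOne()]) == {"x": 2}
```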
+class StructureV2SERViLayoutXLMModelPreprocessor():
+    def __init__(self, ser_dict_path, use_gpu=True):
+        """Create a preprocessor for the Ser-Vi-LayoutXLM model.
+        :param: ser_dict_path: (str) class file path
+        :param: use_gpu: (bool) whether to use GPU for the OCR process
+        """
+        self._manager = None
+        from paddleocr import PaddleOCR
+        self.ocr_engine = PaddleOCR(
+            use_angle_cls=False,
+            det_model_dir=None,
+            rec_model_dir=None,
+            show_log=False,
+            use_gpu=use_gpu)
+
+        pre_process_list = [{
+            'VQATokenLabelEncode': {
+                'class_path': ser_dict_path,
+                'contains_re': False,
+                'ocr_engine': self.ocr_engine,
+                'order_method': "tb-yx"
+            }
+        }, {
+            'VQATokenPad': {
+                'max_seq_len': 512,
+                'return_attention_mask': True
+            }
+        }, {
+            'VQASerTokenChunk': {
+                'max_seq_len': 512,
+                'return_attention_mask': True
+            }
+        }, {
+            'Resize': {
+                'size': [224, 224]
+            }
+        }, {
+            'NormalizeImage': {
+                'std': [58.395, 57.12, 57.375],
+                'mean': [123.675, 116.28, 103.53],
+                'scale': '1',
+                'order': 'hwc'
+            }
+        }, {
+            'ToCHWImage': None
+        }, {
+            'KeepKeys': {
+                'keep_keys': [
+                    'input_ids', 'bbox', 'attention_mask', 'token_type_ids',
+                    'image', 'labels', 'segment_offset_id', 'ocr_info',
+                    'entities'
+                ]
+            }
+        }]
+
+        self.preprocess_op = create_operators(pre_process_list,
+                                              {'infer_mode': True})
+
+    def _transform(self, data, ops=None):
+        """ transform """
+        if ops is None:
+            ops = []
+        for op in ops:
+            data = op(data)
+            if data is None:
+                return None
+        return data
+
+    def run(self, input_im):
+        """Run the preprocess of the Ser-Vi-LayoutXLM model.
+        :param: input_im: (numpy.ndarray) input image
+        """
+        ori_im = input_im.copy()
+        data = {'image': input_im}
+        data = transform(data, self.preprocess_op)
+
+        for idx in range(len(data)):
+            if isinstance(data[idx], np.ndarray):
+                data[idx] = np.expand_dims(data[idx], axis=0)
+            else:
+                data[idx] = [data[idx]]
+
+        return data
+
+
+class StructureV2SERViLayoutXLMModelPostprocessor():
+    def __init__(self, class_path):
+        """Create a postprocessor for the Ser-Vi-LayoutXLM model.
+        :param: class_path: (str) class file path
+        """
+        self.postprocessor_op = VQASerTokenLayoutLMPostProcess(class_path)
+
+    def run(self, preds, batch=None, *args, **kwargs):
+        """Run the postprocess of the Ser-Vi-LayoutXLM model.
+        :param: preds: (list) inference results
+        """
+        return self.postprocessor_op(preds, batch, *args, **kwargs)
+
+
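`predict()` and `batch_predict()` below index the preprocessor output positionally; the positions are fixed by the `KeepKeys` config above, and the code assumes the runtime's four inputs follow the same order. A small self-check restating that layout (not new behavior):

```python
KEEP_KEYS = ['input_ids', 'bbox', 'attention_mask', 'token_type_ids',
             'image', 'labels', 'segment_offset_id', 'ocr_info', 'entities']
IDX = {k: i for i, k in enumerate(KEEP_KEYS)}
# data[0..3] feed the runtime; data[6] and data[7] feed the postprocessor.
assert (IDX['input_ids'], IDX['bbox'], IDX['attention_mask'],
        IDX['token_type_ids']) == (0, 1, 2, 3)
assert IDX['segment_offset_id'] == 6 and IDX['ocr_info'] == 7
```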
+class StructureV2SERViLayoutXLMModel(FastDeployModel):
+    def __init__(self,
+                 model_file,
+                 params_file,
+                 ser_dict_path,
+                 class_path,
+                 config_file="",
+                 runtime_option=None,
+                 model_format=ModelFormat.PADDLE):
+        """Load a SERViLayoutXLM model provided by PP-StructureV2.
+
+        :param model_file: (str)Path of model file, e.g ./ser_vi_layout_xlm/model.pdmodel.
+        :param params_file: (str)Path of parameter file, e.g ./ser_vi_layout_xlm/model.pdiparams, if the model format is ONNX, this parameter will be ignored.
+        :param ser_dict_path: (str) class file path
+        :param class_path: (str) class file path
+        :param runtime_option: (fastdeploy.RuntimeOption)RuntimeOption for inference this model, if it's None, will use the default backend on CPU.
+        :param model_format: (fastdeploy.ModelFormat)Model format of the loaded model.
+        """
+        super(StructureV2SERViLayoutXLMModel, self).__init__(runtime_option)
+
+        assert self._runtime_option.backend != 0, \
+            "RuntimeOption requires an explicit backend setting."
+        self._model = C.vision.ocr.StructureV2SERViLayoutXLMModel(
+            model_file, params_file, config_file, self._runtime_option,
+            model_format)
+
+        assert self.initialized, "SERViLayoutXLM model initialize failed."
+
+        self.preprocessor = StructureV2SERViLayoutXLMModelPreprocessor(
+            ser_dict_path)
+        self.postprocessor = StructureV2SERViLayoutXLMModelPostprocessor(
+            class_path)
+
+        self.input_name_0 = self._model.get_input_info(0).name
+        self.input_name_1 = self._model.get_input_info(1).name
+        self.input_name_2 = self._model.get_input_info(2).name
+        self.input_name_3 = self._model.get_input_info(3).name
+
+    def predict(self, image):
+        assert isinstance(
+            image, np.ndarray), "predict receives a numpy.ndarray (BGR)"
+
+        data = self.preprocessor.run(image)
+        infer_input = {
+            self.input_name_0: data[0],
+            self.input_name_1: data[1],
+            self.input_name_2: data[2],
+            self.input_name_3: data[3],
+        }
+
+        infer_result = self._model.infer(infer_input)
+        infer_result = infer_result[0]
+
+        post_result = self.postprocessor.run(infer_result,
+                                             segment_offset_ids=data[6],
+                                             ocr_infos=data[7])
+
+        return post_result
+
+    def batch_predict(self, image_list):
+        assert isinstance(image_list, list) and \
+            isinstance(image_list[0], np.ndarray), \
+            "batch_predict receives a list of numpy.ndarray (BGR)"
+
+        # read and preprocess images
+        datas = None
+        for image in image_list:
+            data = self.preprocessor.run(image)
+
+            # concatenate data into a batch
+            if datas is None:
+                datas = data
+            else:
+                for idx in range(len(data)):
+                    if isinstance(data[idx], np.ndarray):
+                        datas[idx] = np.concatenate(
+                            (datas[idx], data[idx]), axis=0)
+                    else:
+                        datas[idx].extend(data[idx])
+
+        # infer
+        infer_inputs = {
+            self.input_name_0: datas[0],
+            self.input_name_1: datas[1],
+            self.input_name_2: datas[2],
+            self.input_name_3: datas[3],
+        }
+
+        infer_results = self._model.infer(infer_inputs)
+        infer_results = infer_results[0]
+
+        # postprocess
+        post_results = self.postprocessor.run(infer_results,
+                                              segment_offset_ids=datas[6],
+                                              ocr_infos=datas[7])
+
+        return post_results
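An end-to-end usage sketch for the class above; the paths and label file are assumptions, and the backend helper name may differ across FastDeploy versions (the constructor asserts that a backend is set explicitly):

```python
import cv2
import fastdeploy as fd

option = fd.RuntimeOption()
option.use_paddle_infer_backend()  # any valid backend, but one must be set

model = fd.vision.ocr.StructureV2SERViLayoutXLMModel(
    "ser_vi_layoutxlm/model.pdmodel",    # assumed path
    "ser_vi_layoutxlm/model.pdiparams",  # assumed path
    ser_dict_path="class_list.txt",      # assumed label file
    class_path="class_list.txt",
    runtime_option=option)

result = model.predict(cv2.imread("doc.jpg"))  # BGR ndarray in
```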
diff --git a/python/fastdeploy/vision/ocr/ppocr/utils/__init__.py b/python/fastdeploy/vision/ocr/ppocr/utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/python/fastdeploy/vision/ocr/ppocr/utils/ser_vi_layoutxlm/__init__.py b/python/fastdeploy/vision/ocr/ppocr/utils/ser_vi_layoutxlm/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/python/fastdeploy/vision/ocr/ppocr/utils/ser_vi_layoutxlm/operators.py b/python/fastdeploy/vision/ocr/ppocr/utils/ser_vi_layoutxlm/operators.py
new file mode 100644
index 000000000..449572b8e
--- /dev/null
+++ b/python/fastdeploy/vision/ocr/ppocr/utils/ser_vi_layoutxlm/operators.py
@@ -0,0 +1,91 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import cv2
+import numpy as np
+
+
+class Resize(object):
+    def __init__(self, size=(640, 640), **kwargs):
+        self.size = size
+
+    def resize_image(self, img):
+        resize_h, resize_w = self.size
+        ori_h, ori_w = img.shape[:2]  # (h, w, c)
+        ratio_h = float(resize_h) / ori_h
+        ratio_w = float(resize_w) / ori_w
+        img = cv2.resize(img, (int(resize_w), int(resize_h)))
+        return img, [ratio_h, ratio_w]
+
+    def __call__(self, data):
+        img = data['image']
+        if 'polys' in data:
+            text_polys = data['polys']
+
+        img_resize, [ratio_h, ratio_w] = self.resize_image(img)
+        if 'polys' in data:
+            new_boxes = []
+            for box in text_polys:
+                new_box = []
+                for cord in box:
+                    new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
+                new_boxes.append(new_box)
+            data['polys'] = np.array(new_boxes, dtype=np.float32)
+        data['image'] = img_resize
+        return data
+
+
+class NormalizeImage(object):
+    """ normalize an image, e.g. subtract the mean and divide by the std
+    """
+
+    def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):
+        if isinstance(scale, str):
+            scale = eval(scale)
+        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+        mean = mean if mean is not None else [0.485, 0.456, 0.406]
+        std = std if std is not None else [0.229, 0.224, 0.225]
+
+        shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
+        self.mean = np.array(mean).reshape(shape).astype('float32')
+        self.std = np.array(std).reshape(shape).astype('float32')
+
+    def __call__(self, data):
+        img = data['image']
+        from PIL import Image
+        if isinstance(img, Image.Image):
+            img = np.array(img)
+        assert isinstance(img,
+                          np.ndarray), "invalid input 'img' in NormalizeImage"
+        data['image'] = (
+            img.astype('float32') * self.scale - self.mean) / self.std
+        return data
+
+
+class ToCHWImage(object):
+    """ convert an hwc image to a chw image
+    """
+
+    def __init__(self, **kwargs):
+        pass
+
+    def __call__(self, data):
+        img = data['image']
+        from PIL import Image
+        if isinstance(img, Image.Image):
+            img = np.array(img)
+        data['image'] = img.transpose((2, 0, 1))
+        return data
+
+
+class KeepKeys(object):
+    def __init__(self, keep_keys, **kwargs):
+        self.keep_keys = keep_keys
+
+    def __call__(self, data):
+        data_list = []
+        for key in self.keep_keys:
+            data_list.append(data[key])
+        return data_list
diff --git a/python/fastdeploy/vision/ocr/ppocr/utils/ser_vi_layoutxlm/transforms.py b/python/fastdeploy/vision/ocr/ppocr/utils/ser_vi_layoutxlm/transforms.py
new file mode 100644
index 000000000..dbdb6b0c9
--- /dev/null
+++ b/python/fastdeploy/vision/ocr/ppocr/utils/ser_vi_layoutxlm/transforms.py
@@ -0,0 +1,35 @@
+from .vqa_utils import *
+from .operators import *
+
+
+def transform(data, ops=None):
+    """ transform """
+    if ops is None:
+        ops = []
+    for op in ops:
+        data = op(data)
+        if data is None:
+            return None
+    return data
+
+
+def create_operators(op_param_list, global_config=None):
+    """
+    create operators based on the config
+
+    Args:
+        op_param_list(list): a dict list, used to create some operators
+    """
+    assert isinstance(op_param_list, list), \
+        'operator config should be a list'
+    ops = []
+    for operator in op_param_list:
+        assert isinstance(operator,
+                          dict) and len(operator) == 1, "yaml format error"
+        op_name = list(operator)[0]
+        param = {} if operator[op_name] is None else operator[op_name]
+        if global_config is not None:
+            param.update(global_config)
+        op = eval(op_name)(**param)
+        ops.append(op)
+    return ops
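A hedged usage sketch for the two helpers above, intended to run inside this module's namespace (`create_operators` resolves op names with `eval`, so the ops from operators.py must be in scope):

```python
import numpy as np

ops = create_operators([
    {'Resize': {'size': [224, 224]}},
    {'NormalizeImage': {'std': [58.395, 57.12, 57.375],
                        'mean': [123.675, 116.28, 103.53],
                        'scale': '1', 'order': 'hwc'}},
    {'ToCHWImage': None},
    {'KeepKeys': {'keep_keys': ['image']}},
])
out = transform({'image': np.zeros((32, 32, 3), np.uint8)}, ops)
assert out[0].shape == (3, 224, 224)
```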
diff --git a/python/fastdeploy/vision/ocr/ppocr/utils/ser_vi_layoutxlm/vqa_utils.py b/python/fastdeploy/vision/ocr/ppocr/utils/ser_vi_layoutxlm/vqa_utils.py
new file mode 100644
index 000000000..68e9d7a1e
--- /dev/null
+++ b/python/fastdeploy/vision/ocr/ppocr/utils/ser_vi_layoutxlm/vqa_utils.py
@@ -0,0 +1,569 @@
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import copy
+import json
+import numpy as np
+from copy import deepcopy
+
+from collections import defaultdict
+
+
+def order_by_tbyx(ocr_info):
+    res = sorted(ocr_info, key=lambda r: (r["bbox"][1], r["bbox"][0]))
+    for i in range(len(res) - 1):
+        for j in range(i, 0, -1):
+            if abs(res[j + 1]["bbox"][1] - res[j]["bbox"][1]) < 20 and \
+                    (res[j + 1]["bbox"][0] < res[j]["bbox"][0]):
+                tmp = deepcopy(res[j])
+                res[j] = deepcopy(res[j + 1])
+                res[j + 1] = deepcopy(tmp)
+            else:
+                break
+    return res
+
+
+def load_vqa_bio_label_maps(label_map_path):
+    with open(label_map_path, "r", encoding='utf-8') as fin:
+        lines = fin.readlines()
+    old_lines = [line.strip() for line in lines]
+    lines = ["O"]
+    for line in old_lines:
+        # "O" is already in lines
+        if line.upper() in ["OTHER", "OTHERS", "IGNORE"]:
+            continue
+        lines.append(line)
+    labels = ["O"]
+    for line in lines[1:]:
+        labels.append("B-" + line)
+        labels.append("I-" + line)
+    label2id_map = {label.upper(): idx for idx, label in enumerate(labels)}
+    id2label_map = {idx: label.upper() for idx, label in enumerate(labels)}
+    return label2id_map, id2label_map
+
+
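A worked example of the BIO maps built by `load_vqa_bio_label_maps`, using a hypothetical two-class file:

```python
import tempfile

with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
    f.write("HEADER\nQUESTION\n")

label2id, id2label = load_vqa_bio_label_maps(f.name)
assert label2id == {'O': 0, 'B-HEADER': 1, 'I-HEADER': 2,
                    'B-QUESTION': 3, 'I-QUESTION': 4}
assert id2label[4] == 'I-QUESTION'
```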
+class VQATokenLabelEncode(object):
+    """
+    Label encode for NLP VQA methods
+    """
+
+    def __init__(self,
+                 class_path,
+                 contains_re=False,
+                 add_special_ids=False,
+                 algorithm='LayoutXLM',
+                 use_textline_bbox_info=True,
+                 order_method=None,
+                 infer_mode=False,
+                 ocr_engine=None,
+                 **kwargs):
+        super(VQATokenLabelEncode, self).__init__()
+        from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer, LayoutLMv2Tokenizer
+        tokenizer_dict = {
+            'LayoutXLM': {
+                'class': LayoutXLMTokenizer,
+                'pretrained_model': 'layoutxlm-base-uncased'
+            },
+            'LayoutLM': {
+                'class': LayoutLMTokenizer,
+                'pretrained_model': 'layoutlm-base-uncased'
+            },
+            'LayoutLMv2': {
+                'class': LayoutLMv2Tokenizer,
+                'pretrained_model': 'layoutlmv2-base-uncased'
+            }
+        }
+        self.contains_re = contains_re
+        tokenizer_config = tokenizer_dict[algorithm]
+        self.tokenizer = tokenizer_config['class'].from_pretrained(
+            tokenizer_config['pretrained_model'])
+        self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path)
+        self.add_special_ids = add_special_ids
+        self.infer_mode = infer_mode
+        self.ocr_engine = ocr_engine
+        self.use_textline_bbox_info = use_textline_bbox_info
+        self.order_method = order_method
+        assert self.order_method in [None, "tb-yx"]
+
+    def split_bbox(self, bbox, text, tokenizer):
+        words = text.split()
+        token_bboxes = []
+        curr_word_idx = 0
+        x1, y1, x2, y2 = bbox
+        unit_w = (x2 - x1) / len(text)
+        for idx, word in enumerate(words):
+            curr_w = len(word) * unit_w
+            word_bbox = [x1, y1, x1 + curr_w, y2]
+            token_bboxes.extend([word_bbox] * len(tokenizer.tokenize(word)))
+            x1 += (len(word) + 1) * unit_w
+        return token_bboxes
+
+    def filter_empty_contents(self, ocr_info):
+        """
+        find out the empty texts and remove the links
+        """
+        new_ocr_info = []
+        empty_index = []
+        for idx, info in enumerate(ocr_info):
+            if len(info["transcription"]) > 0:
+                new_ocr_info.append(copy.deepcopy(info))
+            else:
+                empty_index.append(info["id"])
+
+        for idx, info in enumerate(new_ocr_info):
+            new_link = []
+            for link in info["linking"]:
+                if link[0] in empty_index or link[1] in empty_index:
+                    continue
+                new_link.append(link)
+            new_ocr_info[idx]["linking"] = new_link
+        return new_ocr_info
+
+    def __call__(self, data):
+        # load bbox and label info
+        ocr_info = self._load_ocr_info(data)
+
+        for idx in range(len(ocr_info)):
+            if "bbox" not in ocr_info[idx]:
+                ocr_info[idx]["bbox"] = self.trans_poly_to_bbox(
+                    ocr_info[idx]["points"])
+
+        if self.order_method == "tb-yx":
+            ocr_info = order_by_tbyx(ocr_info)
+
+        # for re
+        train_re = self.contains_re and not self.infer_mode
+        if train_re:
+            ocr_info = self.filter_empty_contents(ocr_info)
+
+        height, width, _ = data['image'].shape
+
+        words_list = []
+        bbox_list = []
+        input_ids_list = []
+        token_type_ids_list = []
+        segment_offset_id = []
+        gt_label_list = []
+
+        entities = []
+
+        if train_re:
+            relations = []
+            id2label = {}
+            entity_id_to_index_map = {}
+            empty_entity = set()
+
+        data['ocr_info'] = copy.deepcopy(ocr_info)
+
+        for info in ocr_info:
+            text = info["transcription"]
+            if len(text) <= 0:
+                continue
+            if train_re:
+                # for re
+                if len(text) == 0:
+                    empty_entity.add(info["id"])
+                    continue
+                id2label[info["id"]] = info["label"]
+                relations.extend([tuple(sorted(l)) for l in info["linking"]])
+            # smooth_box
+            info["bbox"] = self.trans_poly_to_bbox(info["points"])
+
+            encode_res = self.tokenizer.encode(
+                text,
+                pad_to_max_seq_len=False,
+                return_attention_mask=True,
+                return_token_type_ids=True)
+
+            if not self.add_special_ids:
+                # TODO: use tok.all_special_ids to remove
+                encode_res["input_ids"] = encode_res["input_ids"][1:-1]
+                encode_res["token_type_ids"] = encode_res["token_type_ids"][1:-1]
+                encode_res["attention_mask"] = encode_res["attention_mask"][1:-1]
+
+            if self.use_textline_bbox_info:
+                bbox = [info["bbox"]] * len(encode_res["input_ids"])
+            else:
+                bbox = self.split_bbox(info["bbox"], info["transcription"],
+                                       self.tokenizer)
+            if len(bbox) <= 0:
+                continue
+            bbox = self._smooth_box(bbox, height, width)
+            if self.add_special_ids:
+                bbox.insert(0, [0, 0, 0, 0])
+                bbox.append([0, 0, 0, 0])
+
+            # parse label
+            if not self.infer_mode:
+                label = info['label']
+                gt_label = self._parse_label(label, encode_res)
+
+            # construct entities for re
+            if train_re:
+                if gt_label[0] != self.label2id_map["O"]:
+                    entity_id_to_index_map[info["id"]] = len(entities)
+                    label = label.upper()
+                    entities.append({
+                        "start": len(input_ids_list),
+                        "end":
+                        len(input_ids_list) + len(encode_res["input_ids"]),
+                        "label": label.upper(),
+                    })
+            else:
+                entities.append({
+                    "start": len(input_ids_list),
+                    "end": len(input_ids_list) + len(encode_res["input_ids"]),
+                    "label": 'O',
+                })
+            input_ids_list.extend(encode_res["input_ids"])
+            token_type_ids_list.extend(encode_res["token_type_ids"])
+            bbox_list.extend(bbox)
+            words_list.append(text)
+            segment_offset_id.append(len(input_ids_list))
+            if not self.infer_mode:
+                gt_label_list.extend(gt_label)
+
+        data['input_ids'] = input_ids_list
+        data['token_type_ids'] = token_type_ids_list
+        data['bbox'] = bbox_list
+        data['attention_mask'] = [1] * len(input_ids_list)
+        data['labels'] = gt_label_list
+        data['segment_offset_id'] = segment_offset_id
+        data['tokenizer_params'] = dict(
+            padding_side=self.tokenizer.padding_side,
+            pad_token_type_id=self.tokenizer.pad_token_type_id,
+            pad_token_id=self.tokenizer.pad_token_id)
+        data['entities'] = entities
+
+        if train_re:
+            data['relations'] = relations
+            data['id2label'] = id2label
+            data['empty_entity'] = empty_entity
+            data['entity_id_to_index_map'] = entity_id_to_index_map
+        return data
+
+    def trans_poly_to_bbox(self, poly):
+        x1 = int(np.min([p[0] for p in poly]))
+        x2 = int(np.max([p[0] for p in poly]))
+        y1 = int(np.min([p[1] for p in poly]))
+        y2 = int(np.max([p[1] for p in poly]))
+        return [x1, y1, x2, y2]
+
+    def _load_ocr_info(self, data):
+        if self.infer_mode:
+            ocr_result = self.ocr_engine.ocr(data['image'], cls=False)[0]
+            ocr_info = []
+            for res in ocr_result:
+                ocr_info.append({
+                    "transcription": res[1][0],
+                    "bbox": self.trans_poly_to_bbox(res[0]),
+                    "points": res[0],
+                })
+            return ocr_info
+        else:
+            info = data['label']
+            # read text info
+            info_dict = json.loads(info)
+            return info_dict
+
+    def _smooth_box(self, bboxes, height, width):
+        bboxes = np.array(bboxes)
+        bboxes[:, 0] = bboxes[:, 0] * 1000 / width
+        bboxes[:, 2] = bboxes[:, 2] * 1000 / width
+        bboxes[:, 1] = bboxes[:, 1] * 1000 / height
+        bboxes[:, 3] = bboxes[:, 3] * 1000 / height
+        bboxes = bboxes.astype("int64").tolist()
+        return bboxes
+
+    def _parse_label(self, label, encode_res):
+        gt_label = []
+        if label.lower() in ["other", "others", "ignore"]:
+            gt_label.extend([0] * len(encode_res["input_ids"]))
+        else:
+            gt_label.append(self.label2id_map[("b-" + label).upper()])
+            gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
+                            (len(encode_res["input_ids"]) - 1))
+        return gt_label
+
+
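LayoutXLM-family models take boxes on a 0-1000 grid regardless of the image size; a standalone restatement of `_smooth_box` above, with a worked value:

```python
import numpy as np

def smooth_box(bboxes, height, width):
    b = np.asarray(bboxes, dtype="float64")
    b[:, [0, 2]] *= 1000.0 / width   # x coordinates
    b[:, [1, 3]] *= 1000.0 / height  # y coordinates
    return b.astype("int64").tolist()

assert smooth_box([[10, 20, 200, 40]], height=100, width=400) == \
    [[25, 200, 500, 400]]
```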
+class VQATokenPad(object):
+    def __init__(self,
+                 max_seq_len=512,
+                 pad_to_max_seq_len=True,
+                 return_attention_mask=True,
+                 return_token_type_ids=True,
+                 truncation_strategy="longest_first",
+                 return_overflowing_tokens=False,
+                 return_special_tokens_mask=False,
+                 infer_mode=False,
+                 **kwargs):
+        self.max_seq_len = max_seq_len
+        self.pad_to_max_seq_len = pad_to_max_seq_len
+        self.return_attention_mask = return_attention_mask
+        self.return_token_type_ids = return_token_type_ids
+        self.truncation_strategy = truncation_strategy
+        self.return_overflowing_tokens = return_overflowing_tokens
+        self.return_special_tokens_mask = return_special_tokens_mask
+        self.infer_mode = infer_mode
+
+    def __call__(self, data):
+        import paddle
+        self.pad_token_label_id = paddle.nn.CrossEntropyLoss().ignore_index
+        needs_to_be_padded = self.pad_to_max_seq_len and \
+            len(data["input_ids"]) < self.max_seq_len
+
+        if needs_to_be_padded:
+            if 'tokenizer_params' in data:
+                tokenizer_params = data.pop('tokenizer_params')
+            else:
+                tokenizer_params = dict(
+                    padding_side='right', pad_token_type_id=0, pad_token_id=1)
+
+            difference = self.max_seq_len - len(data["input_ids"])
+            if tokenizer_params['padding_side'] == 'right':
+                if self.return_attention_mask:
+                    data["attention_mask"] = [1] * len(
+                        data["input_ids"]) + [0] * difference
+                if self.return_token_type_ids:
+                    data["token_type_ids"] = (
+                        data["token_type_ids"] +
+                        [tokenizer_params['pad_token_type_id']] * difference)
+                if self.return_special_tokens_mask:
+                    data["special_tokens_mask"] = data[
+                        "special_tokens_mask"] + [1] * difference
+                data["input_ids"] = data["input_ids"] + [
+                    tokenizer_params['pad_token_id']
+                ] * difference
+                if not self.infer_mode:
+                    data["labels"] = data[
+                        "labels"] + [self.pad_token_label_id] * difference
+                data["bbox"] = data["bbox"] + [[0, 0, 0, 0]] * difference
+            elif tokenizer_params['padding_side'] == 'left':
+                if self.return_attention_mask:
+                    data["attention_mask"] = [0] * difference + [1] * len(
+                        data["input_ids"])
+                if self.return_token_type_ids:
+                    data["token_type_ids"] = (
+                        [tokenizer_params['pad_token_type_id']] * difference +
+                        data["token_type_ids"])
+                if self.return_special_tokens_mask:
+                    data["special_tokens_mask"] = [
+                        1
+                    ] * difference + data["special_tokens_mask"]
+                data["input_ids"] = [
+                    tokenizer_params['pad_token_id']
+                ] * difference + data["input_ids"]
+                if not self.infer_mode:
+                    data["labels"] = [
+                        self.pad_token_label_id
+                    ] * difference + data["labels"]
+                data["bbox"] = [[0, 0, 0, 0]] * difference + data["bbox"]
+        else:
+            if self.return_attention_mask:
+                data["attention_mask"] = [1] * len(data["input_ids"])
+
+        for key in data:
+            if key in [
+                    'input_ids', 'labels', 'token_type_ids', 'bbox',
+                    'attention_mask'
+            ]:
+                if self.infer_mode:
+                    if key != 'labels':
+                        length = min(len(data[key]), self.max_seq_len)
+                        data[key] = data[key][:length]
+                    else:
+                        continue
+                data[key] = np.array(data[key], dtype='int64')
+        return data
+
+
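A standalone sketch of the right-padding branch above; `pad_token_id=1` and `pad_token_type_id=0` are the fallback defaults, and -100 is paddle's `CrossEntropyLoss` ignore_index used to pad labels:

```python
max_len, pad_id, ignore = 8, 1, -100
ids = [101, 7, 9, 102]
diff = max_len - len(ids)
attention = [1] * len(ids) + [0] * diff
padded_ids = ids + [pad_id] * diff
padded_labels = [5, 6, 6, 5] + [ignore] * diff
assert padded_ids == [101, 7, 9, 102, 1, 1, 1, 1]
assert attention == [1, 1, 1, 1, 0, 0, 0, 0]
```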
+class VQASerTokenChunk(object):
+    def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):
+        self.max_seq_len = max_seq_len
+        self.infer_mode = infer_mode
+
+    def __call__(self, data):
+        encoded_inputs_all = []
+        seq_len = len(data['input_ids'])
+        for index in range(0, seq_len, self.max_seq_len):
+            chunk_beg = index
+            chunk_end = min(index + self.max_seq_len, seq_len)
+            encoded_inputs_example = {}
+            for key in data:
+                if key in [
+                        'label', 'input_ids', 'labels', 'token_type_ids',
+                        'bbox', 'attention_mask'
+                ]:
+                    if self.infer_mode and key == 'labels':
+                        encoded_inputs_example[key] = data[key]
+                    else:
+                        encoded_inputs_example[key] = data[key][chunk_beg:
+                                                                chunk_end]
+                else:
+                    encoded_inputs_example[key] = data[key]
+
+            encoded_inputs_all.append(encoded_inputs_example)
+        if len(encoded_inputs_all) == 0:
+            return None
+        return encoded_inputs_all[0]
+
+
+class VQAReTokenChunk(object):
+    def __init__(self,
+                 max_seq_len=512,
+                 entities_labels=None,
+                 infer_mode=False,
+                 **kwargs):
+        self.max_seq_len = max_seq_len
+        self.entities_labels = {
+            'HEADER': 0,
+            'QUESTION': 1,
+            'ANSWER': 2
+        } if entities_labels is None else entities_labels
+        self.infer_mode = infer_mode
+
+    def __call__(self, data):
+        # prepare data
+        entities = data.pop('entities')
+        relations = data.pop('relations')
+        encoded_inputs_all = []
+        for index in range(0, len(data["input_ids"]), self.max_seq_len):
+            item = {}
+            for key in data:
+                if key in [
+                        'label', 'input_ids', 'labels', 'token_type_ids',
+                        'bbox', 'attention_mask'
+                ]:
+                    if self.infer_mode and key == 'labels':
+                        item[key] = data[key]
+                    else:
+                        item[key] = data[key][index:index + self.max_seq_len]
+                else:
+                    item[key] = data[key]
+            # select entities in the current chunk
+            entities_in_this_span = []
+            global_to_local_map = {}
+            for entity_id, entity in enumerate(entities):
+                if (index <= entity["start"] < index + self.max_seq_len and
+                        index <= entity["end"] < index + self.max_seq_len):
+                    entity["start"] = entity["start"] - index
+                    entity["end"] = entity["end"] - index
+                    global_to_local_map[entity_id] = len(entities_in_this_span)
+                    entities_in_this_span.append(entity)
+
+            # select relations in the current chunk
+            relations_in_this_span = []
+            for relation in relations:
+                if (index <= relation["start_index"] < index + self.max_seq_len
+                        and index <= relation["end_index"] <
+                        index + self.max_seq_len):
+                    relations_in_this_span.append({
+                        "head": global_to_local_map[relation["head"]],
+                        "tail": global_to_local_map[relation["tail"]],
+                        "start_index": relation["start_index"] - index,
+                        "end_index": relation["end_index"] - index,
+                    })
+            item.update({
+                "entities": self.reformat(entities_in_this_span),
+                "relations": self.reformat(relations_in_this_span),
+            })
+            if len(item['entities']) > 0:
+                item['entities']['label'] = [
+                    self.entities_labels[x] for x in item['entities']['label']
+                ]
+                encoded_inputs_all.append(item)
+        if len(encoded_inputs_all) == 0:
+            return None
+        return encoded_inputs_all[0]
+
+    def reformat(self, data):
+        new_data = defaultdict(list)
+        for item in data:
+            for k, v in item.items():
+                new_data[k].append(v)
+        return new_data
+
+
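Both chunkers above window the token stream at `max_seq_len` and return only the first chunk; a minimal restatement of the windowing:

```python
seq = list(range(10))
max_seq_len = 4
chunks = [seq[i:i + max_seq_len] for i in range(0, len(seq), max_seq_len)]
assert chunks[0] == [0, 1, 2, 3]  # only chunks[0] survives downstream
```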
+class VQASerTokenLayoutLMPostProcess(object):
+    """ Convert between text-label and text-index """
+
+    def __init__(self, class_path, **kwargs):
+        super(VQASerTokenLayoutLMPostProcess, self).__init__()
+        label2id_map, self.id2label_map = load_vqa_bio_label_maps(class_path)
+
+        self.label2id_map_for_draw = dict()
+        for key in label2id_map:
+            if key.startswith("I-"):
+                self.label2id_map_for_draw[key] = label2id_map["B" + key[1:]]
+            else:
+                self.label2id_map_for_draw[key] = label2id_map[key]
+
+        self.id2label_map_for_show = dict()
+        for key in self.label2id_map_for_draw:
+            val = self.label2id_map_for_draw[key]
+            if key == "O":
+                self.id2label_map_for_show[val] = key
+            elif key.startswith("B-") or key.startswith("I-"):
+                self.id2label_map_for_show[val] = key[2:]
+            else:
+                self.id2label_map_for_show[val] = key
+
+    def __call__(self, preds, batch=None, *args, **kwargs):
+        import paddle
+        if isinstance(preds, tuple):
+            preds = preds[0]
+        if isinstance(preds, paddle.Tensor):
+            preds = preds.numpy()
+
+        if batch is not None:
+            return self._metric(preds, batch[5])
+        else:
+            return self._infer(preds, **kwargs)
+
+    def _metric(self, preds, label):
+        pred_idxs = preds.argmax(axis=2)
+        decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
+        label_decode_out_list = [[] for _ in range(pred_idxs.shape[0])]
+
+        for i in range(pred_idxs.shape[0]):
+            for j in range(pred_idxs.shape[1]):
+                if label[i, j] != -100:
+                    label_decode_out_list[i].append(
+                        self.id2label_map[label[i, j]])
+                    decode_out_list[i].append(
+                        self.id2label_map[pred_idxs[i, j]])
+        return decode_out_list, label_decode_out_list
+
+    def _infer(self, preds, segment_offset_ids, ocr_infos):
+        results = []
+
+        for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids,
+                                                     ocr_infos):
+            pred = np.argmax(pred, axis=1)
+            pred = [self.id2label_map[idx] for idx in pred]
+
+            for idx in range(len(segment_offset_id)):
+                if idx == 0:
+                    start_id = 0
+                else:
+                    start_id = segment_offset_id[idx - 1]
+
+                end_id = segment_offset_id[idx]
+
+                curr_pred = pred[start_id:end_id]
+                curr_pred = [self.label2id_map_for_draw[p] for p in curr_pred]
+
+                if len(curr_pred) <= 0:
+                    pred_id = 0
+                else:
+                    counts = np.bincount(curr_pred)
+                    pred_id = np.argmax(counts)
+                ocr_info[idx]["pred_id"] = int(pred_id)
+                ocr_info[idx]["pred"] = self.id2label_map_for_show[int(
+                    pred_id)]
+            results.append(ocr_info)
+        return results
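For reference, `_infer` returns one list per input image containing the OCR info dicts built by `_load_ocr_info`, each augmented with a majority-vote prediction over its token segment. A hedged sketch of the result shape (field values are illustrative, using the two-class mapping from the earlier example):

```python
example_results = [[{               # one inner list per input image
    "transcription": "Invoice No.",
    "bbox": [60, 14, 210, 42],      # pixel-space box from the OCR engine
    "points": [[60, 14], [210, 14], [210, 42], [60, 42]],
    "pred_id": 3,                   # majority-vote id over the segment
    "pred": "QUESTION",             # B-/I- prefix stripped for display
}]]
```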