mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
Sync v2.0 version of code to github repo
This commit is contained in:
@@ -14,159 +14,156 @@
|
||||
# limitations under the License.
|
||||
"""
|
||||
|
||||
# cipher_token=WjI1fQOvhN # do not edit this line
|
||||
|
||||
import os
|
||||
import re
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from typing import Dict, Optional, Tuple, List
|
||||
import numpy as np
|
||||
import sentencepiece as spm
|
||||
from paddlenlp.transformers import AddedToken, PretrainedTokenizer
|
||||
from paddlenlp.utils import logger
|
||||
|
||||
__all__ = ["ErnieBotTokenizer"]
|
||||
import paddle
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "spm.model"}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {},
|
||||
"tokenizer_file": {},
|
||||
}
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}
|
||||
from paddleformers.utils.log import logger
|
||||
from paddleformers.transformers import PretrainedTokenizer
|
||||
from paddleformers.transformers.tokenizer_utils_base import (
|
||||
PaddingStrategy,
|
||||
TextInput,
|
||||
)
|
||||
|
||||
|
||||
class ErnieBotTokenizer(PretrainedTokenizer):
|
||||
"""
|
||||
Construct a ErnieBot tokenizer. Based on byte-level Byte-Pair-Encoding.
|
||||
Args:
|
||||
vocab_file (`str`):
|
||||
Path to the vocabulary file.
|
||||
一个更好用的 `ErnieBotToknizer`,
|
||||
能 encode 目前 sft/ppo 阶段的特殊token,也支持多模态。
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
resource_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
resource_files_names = {
|
||||
"vocab_file": "tokenizer.model",
|
||||
}
|
||||
pretrained_resource_files_map = {"vocab_file": {"ernie-bot-10b": None}}
|
||||
pretrained_init_configuration = {
|
||||
"ernie-bot-10b": {},
|
||||
}
|
||||
model_input_names = ["input_ids", "position_ids", "attention_mask", "labels"]
|
||||
padding_side = "right"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
unk_token="<unk>",
|
||||
bos_token="<s>",
|
||||
cls_token="<cls>",
|
||||
eos_token="</s>",
|
||||
mask_token="<mask:0>",
|
||||
pad_token="<pad>",
|
||||
sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
add_bos_token=True,
|
||||
add_eos_token=False,
|
||||
clean_up_tokenization_spaces=False,
|
||||
sep_token="<sep>",
|
||||
unk_token="<unk>",
|
||||
additional_special_tokens=None,
|
||||
verbose=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_file = vocab_file
|
||||
self.add_bos_token = add_bos_token
|
||||
self.add_eos_token = add_eos_token
|
||||
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
|
||||
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
||||
self.sp_model.Load(vocab_file)
|
||||
bos_token = AddedToken(bos_token,
|
||||
lstrip=False, rstrip=False) if isinstance(
|
||||
bos_token, str) else bos_token
|
||||
eos_token = AddedToken(eos_token,
|
||||
lstrip=False, rstrip=False) if isinstance(
|
||||
eos_token, str) else eos_token
|
||||
unk_token = AddedToken(unk_token,
|
||||
lstrip=False, rstrip=False) if isinstance(
|
||||
unk_token, str) else unk_token
|
||||
pad_token = AddedToken(pad_token,
|
||||
lstrip=False, rstrip=False) if isinstance(
|
||||
pad_token, str) else pad_token
|
||||
"""doc"""
|
||||
if additional_special_tokens is None:
|
||||
additional_special_tokens = ["<mask:1>", "<mask:7>"]
|
||||
super().__init__(
|
||||
bos_token=bos_token,
|
||||
cls_token=cls_token,
|
||||
eos_token=eos_token,
|
||||
unk_token=unk_token,
|
||||
mask_token=mask_token,
|
||||
pad_token=pad_token,
|
||||
add_bos_token=add_bos_token,
|
||||
add_eos_token=add_eos_token,
|
||||
sep_token=sep_token,
|
||||
unk_token=unk_token,
|
||||
additional_special_tokens=additional_special_tokens,
|
||||
verbose=False,
|
||||
sp_model_kwargs=self.sp_model_kwargs,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
**kwargs,
|
||||
)
|
||||
# for eb35 reader
|
||||
self.bos_id = self.bos_token_id
|
||||
self.eos_id = self.eos_token_id
|
||||
self.sep_id = self.sep_token_id
|
||||
self.pad_id = self.pad_token_id
|
||||
self.unk_id = self.unk_token_id
|
||||
self.vocab_file = vocab_file
|
||||
self.sp_model = spm.SentencePieceProcessor()
|
||||
self.sp_model.Load(vocab_file)
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
state["sp_model"] = None
|
||||
return state
|
||||
@property
|
||||
def space_token(self):
|
||||
"""doc"""
|
||||
return "<mask:1>"
|
||||
|
||||
def __setstate__(self, d):
|
||||
self.__dict__ = d
|
||||
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
||||
self.sp_model.Load(self.vocab_file)
|
||||
@property
|
||||
def space_token_id(self):
|
||||
"""doc"""
|
||||
return self.sp_model.piece_to_id("<mask:1>")
|
||||
|
||||
@property
|
||||
def gend_token(self):
|
||||
"""doc"""
|
||||
return "<mask:7>"
|
||||
|
||||
@property
|
||||
def gend_token_id(self):
|
||||
"""doc"""
|
||||
return self.sp_model.piece_to_id("<mask:7>")
|
||||
|
||||
@property
|
||||
def im_start_id(self):
|
||||
"""doc"""
|
||||
return self.sp_model.piece_to_id("<|im_start|>")
|
||||
|
||||
@property
|
||||
def im_end_id(self):
|
||||
"""doc"""
|
||||
return self.sp_model.piece_to_id("<|im_end|>")
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
"""Returns vocab size"""
|
||||
return self.sp_model.get_piece_size()
|
||||
"""doc"""
|
||||
return self.sp_model.vocab_size()
|
||||
|
||||
def get_vocab(self):
|
||||
"""Returns vocab as a dict"""
|
||||
vocab = {
|
||||
self.convert_ids_to_tokens(i): i
|
||||
for i in range(self.vocab_size)
|
||||
}
|
||||
"""doc"""
|
||||
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
||||
vocab.update(self.added_tokens_encoder)
|
||||
return vocab
|
||||
|
||||
def tokenize(self, text):
|
||||
"""Returns a tokenized string."""
|
||||
return self._tokenize(text)
|
||||
|
||||
def _tokenize(self, text):
|
||||
"""Returns a tokenized string."""
|
||||
return self.sp_model.encode(text, out_type=str)
|
||||
|
||||
def decode(self,
|
||||
tokens,
|
||||
skip_special_tokens=False,
|
||||
clean_up_tokenization_spaces=False):
|
||||
"""Returns a tokenized string."""
|
||||
return self.sp_model.decode(tokens)
|
||||
"""doc"""
|
||||
return self.sp_model.encode_as_pieces(text)
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
"""Converts a token (str) in an id using the vocab."""
|
||||
"""doc"""
|
||||
return self.sp_model.piece_to_id(token)
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
token = self.sp_model.IdToPiece(index)
|
||||
return token
|
||||
def _convert_id_to_token(self, id):
|
||||
"""doc"""
|
||||
return self.sp_model.id_to_piece(id)
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
"""Converts a sequence of tokens (string) in a single string."""
|
||||
current_sub_tokens = []
|
||||
out_string = ""
|
||||
prev_is_special = False
|
||||
for i, token in enumerate(tokens):
|
||||
# prev_is_special = False
|
||||
for token in tokens:
|
||||
# make sure that special tokens are not decoded using sentencepiece model
|
||||
if token in self.all_special_tokens:
|
||||
if not prev_is_special and i != 0:
|
||||
out_string += " "
|
||||
# if not prev_is_special:
|
||||
# out_string += " "
|
||||
out_string += self.sp_model.decode(current_sub_tokens) + token
|
||||
prev_is_special = True
|
||||
# prev_is_special = True
|
||||
|
||||
current_sub_tokens = []
|
||||
else:
|
||||
current_sub_tokens.append(token)
|
||||
prev_is_special = False
|
||||
# prev_is_special = False
|
||||
out_string += self.sp_model.decode(current_sub_tokens)
|
||||
return out_string
|
||||
return out_string # .strip()
|
||||
|
||||
def save_vocabulary(self,
|
||||
save_directory,
|
||||
filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
def prepare_for_model(self, *args, **kwargs):
|
||||
"""doc"""
|
||||
if "add_special_tokens" in kwargs:
|
||||
kwargs.pop("add_special_tokens")
|
||||
# logger.warning(f'ErnieBotTokenizer v2 does not support `add_special_tokens`')
|
||||
return super().prepare_for_model(*args, **kwargs)
|
||||
|
||||
def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
"""
|
||||
Save the vocabulary and special tokens file to a directory.
|
||||
Args:
|
||||
@@ -176,94 +173,212 @@ class ErnieBotTokenizer(PretrainedTokenizer):
|
||||
`Tuple(str)`: Paths to the files saved.
|
||||
"""
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(
|
||||
f"Vocabulary path ({save_directory}) should be a directory")
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||
return
|
||||
out_vocab_file = os.path.join(
|
||||
save_directory,
|
||||
(filename_prefix + "-" if filename_prefix else "") +
|
||||
VOCAB_FILES_NAMES["vocab_file"])
|
||||
|
||||
if os.path.abspath(self.vocab_file) != os.path.abspath(
|
||||
out_vocab_file) and os.path.isfile(self.vocab_file):
|
||||
(filename_prefix + "-" if filename_prefix else "") + self.resource_files_names["vocab_file"],
|
||||
)
|
||||
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
elif not os.path.isfile(self.vocab_file):
|
||||
with open(out_vocab_file, "wb") as fi:
|
||||
content_spiece_model = self.sp_model.serialized_model_proto()
|
||||
fi.write(content_spiece_model)
|
||||
return (out_vocab_file,)
|
||||
|
||||
return (out_vocab_file, )
|
||||
|
||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
||||
""" build_inputs_with_special_tokens """
|
||||
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
||||
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
||||
|
||||
output = bos_token_id + token_ids_0 + eos_token_id
|
||||
|
||||
if token_ids_1 is not None:
|
||||
output = output + bos_token_id + token_ids_1 + eos_token_id
|
||||
|
||||
return output
|
||||
|
||||
def get_special_tokens_mask(
|
||||
self,
|
||||
token_ids_0: List[int],
|
||||
token_ids_1: Optional[List[int]] = None,
|
||||
already_has_special_tokens: bool = False) -> List[int]:
|
||||
def tokenize(self, text: TextInput, **kwargs) -> List[str]:
|
||||
"""
|
||||
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||
special tokens using the tokenizer `prepare_for_model` method.
|
||||
Converts a string in a sequence of tokens, using the tokenizer.
|
||||
|
||||
Split in words for word-based vocabulary or sub-words for sub-word-based vocabularies
|
||||
(BPE/SentencePieces/WordPieces). Takes care of added tokens.
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the token list is already formatted with special tokens for the model.
|
||||
text (`str`):
|
||||
The sequence to be encoded.
|
||||
**kwargs (additional keyword arguments):
|
||||
Passed along to the model-specific `prepare_for_tokenization` preprocessing method.
|
||||
|
||||
Returns:
|
||||
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
||||
`List[str]`: The list of tokens.
|
||||
"""
|
||||
if already_has_special_tokens:
|
||||
return super().get_special_tokens_mask(
|
||||
token_ids_0=token_ids_0,
|
||||
token_ids_1=token_ids_1,
|
||||
already_has_special_tokens=True)
|
||||
# Simple mapping string => AddedToken for special tokens with specific tokenization behaviors
|
||||
# all_special_tokens_extended = dict(
|
||||
# (str(t), t)
|
||||
# for t in self.all_special_tokens_extended
|
||||
# if isinstance(t, AddedToken)
|
||||
# )
|
||||
|
||||
bos_token_id = [1] if self.add_bos_token else []
|
||||
eos_token_id = [1] if self.add_eos_token else []
|
||||
text, kwargs = self.prepare_for_tokenization(text, **kwargs)
|
||||
|
||||
if token_ids_1 is None:
|
||||
return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id
|
||||
return (bos_token_id + ([0] * len(token_ids_0)) + eos_token_id +
|
||||
bos_token_id + ([0] * len(token_ids_1)) + eos_token_id)
|
||||
# TODO: should this be in the base class?
|
||||
if hasattr(self, "do_lower_case") and self.do_lower_case:
|
||||
# convert non-special tokens to lowercase
|
||||
escaped_special_toks = [
|
||||
re.escape(s_tok) for s_tok in (self.unique_no_split_tokens + self.all_special_tokens)
|
||||
]
|
||||
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
|
||||
text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
|
||||
|
||||
def create_token_type_ids_from_sequences(
|
||||
self,
|
||||
token_ids_0: List[int],
|
||||
token_ids_1: Optional[List[int]] = None) -> List[int]:
|
||||
"""
|
||||
Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT
|
||||
sequence pair mask has the following format:
|
||||
```
|
||||
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
|
||||
| first sequence | second sequence |
|
||||
```
|
||||
if token_ids_1 is None, only returns the first portion of the mask (0s).
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of ids.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
Returns:
|
||||
`List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
|
||||
"""
|
||||
bos_token_id = [self.bos_token_id] if self.add_bos_token else []
|
||||
eos_token_id = [self.eos_token_id] if self.add_eos_token else []
|
||||
no_split_token = set(self.unique_no_split_tokens)
|
||||
tokens = self.tokens_trie.split(text)
|
||||
|
||||
output = [0] * len(bos_token_id + token_ids_0 + eos_token_id)
|
||||
# ["This is something", "<special_token_1>", " else"]
|
||||
# for i, token in enumerate(tokens):
|
||||
# if token in no_split_token:
|
||||
# tok_extended = all_special_tokens_extended.get(token, None)
|
||||
# print(f'>>>{token}|{tok_extended}|{all_special_tokens_extended}<<<')
|
||||
# left = tokens[i - 1] if i > 0 else None
|
||||
# right = tokens[i + 1] if i < len(tokens) - 1 else None
|
||||
# if isinstance(tok_extended, AddedToken):
|
||||
# if tok_extended.rstrip and right:
|
||||
# # A bit counter-intuitive but we strip the left of the string
|
||||
# # since tok_extended.rstrip means the special token is eating all white spaces on its right
|
||||
# tokens[i + 1] = right.lstrip()
|
||||
# # Strip white spaces on the left
|
||||
# if tok_extended.lstrip and left:
|
||||
# tokens[i - 1] = left.rstrip() # Opposite here
|
||||
# else:
|
||||
# We strip left and right by default
|
||||
# if right:
|
||||
# tokens[i + 1] = right.lstrip()
|
||||
# if left:
|
||||
# tokens[i - 1] = left.rstrip()
|
||||
# ["This is something", "<special_token_1>", "else"]
|
||||
tokenized_text = []
|
||||
for token in tokens:
|
||||
# Need to skip eventual empty (fully stripped) tokens
|
||||
if not token:
|
||||
continue
|
||||
if token in no_split_token:
|
||||
tokenized_text.append(token)
|
||||
else:
|
||||
tokenized_text.extend(self._tokenize(token))
|
||||
# ["This", " is", " something", "<special_token_1>", "else"]
|
||||
return tokenized_text
|
||||
|
||||
if token_ids_1 is not None:
|
||||
output += [1] * len(bos_token_id + token_ids_1 + eos_token_id)
|
||||
|
||||
return output
|
||||
def _decode(self, *args, **kwargs):
|
||||
"""doc"""
|
||||
kwargs.pop("clean_up_tokenization_spaces", None)
|
||||
kwargs.pop("spaces_between_special_tokens", None)
|
||||
return super()._decode(
|
||||
*args,
|
||||
**kwargs,
|
||||
clean_up_tokenization_spaces=False,
|
||||
spaces_between_special_tokens=False,
|
||||
)
|
||||
|
||||
def _pad(
|
||||
self,
|
||||
encoded_inputs: Dict,
|
||||
max_length: Optional[int] = None,
|
||||
padding_strategy=PaddingStrategy.DO_NOT_PAD,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
) -> dict:
|
||||
"""doc"""
|
||||
if return_attention_mask is None:
|
||||
return_attention_mask = "attention_mask" in self.model_input_names
|
||||
if return_attention_mask:
|
||||
required_input = encoded_inputs[self.model_input_names[0]]
|
||||
if padding_strategy == PaddingStrategy.LONGEST:
|
||||
max_length = len(required_input)
|
||||
if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
|
||||
max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
|
||||
needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
|
||||
if "attention_mask" in encoded_inputs and encoded_inputs["attention_mask"] is not None:
|
||||
attention_mask = encoded_inputs.pop("attention_mask")
|
||||
if isinstance(attention_mask, paddle.Tensor):
|
||||
attention_mask = attention_mask.numpy()
|
||||
elif isinstance(attention_mask, list):
|
||||
attention_mask = np.array(attention_mask)
|
||||
elif not isinstance(attention_mask, np.ndarray):
|
||||
raise ValueError(f"Unexpected type {type(attention_mask)} of attention_mask, ")
|
||||
else:
|
||||
attention_mask = np.tril(np.ones((len(required_input), len(required_input)), dtype=np.int64))
|
||||
attention_mask = np.expand_dims(attention_mask, axis=0)
|
||||
if needs_to_be_padded:
|
||||
difference = max_length - len(required_input)
|
||||
if self.padding_side == "right":
|
||||
if attention_mask.ndim == 1:
|
||||
pad_width = [(0, difference)]
|
||||
else:
|
||||
pad_width = [(0, 0), (0, difference), (0, difference)]
|
||||
elif self.padding_side == "left":
|
||||
if attention_mask.ndim == 1:
|
||||
pad_width = [(difference, 0)]
|
||||
else:
|
||||
pad_width = [(0, 0), (difference, 0), (difference, 0)]
|
||||
else:
|
||||
raise ValueError("Invalid padding strategy:" + str(self.padding_side))
|
||||
attention_mask = np.pad(
|
||||
attention_mask,
|
||||
pad_width=pad_width,
|
||||
mode="constant",
|
||||
constant_values=0,
|
||||
)
|
||||
encoded_inputs = super()._pad(
|
||||
encoded_inputs,
|
||||
max_length,
|
||||
padding_strategy=padding_strategy,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_attention_mask=False,
|
||||
)
|
||||
if return_attention_mask:
|
||||
encoded_inputs["attention_mask"] = attention_mask.tolist()
|
||||
return encoded_inputs
|
||||
|
||||
|
||||
def add_special_tokens(
|
||||
tokenizer,
|
||||
special_tokens_info,
|
||||
use_ocr_specialtoken=False,
|
||||
use_crop_specialtoken=False,
|
||||
special_token_ids_start=254208,
|
||||
special_token_ids_end=256256,
|
||||
):
|
||||
"""
|
||||
增加 special token
|
||||
|
||||
placeholder [<|IMAGE_PLACEHOLDER|>, <|AUDIO_PLACEHOLDER|>, <|VIDEO_PLACEHOLDER|>] 共3个
|
||||
|
||||
模态起始截止 special tokens [<|BOI|> <|EOI|> <|BOA|> <|EOA|> <|BOV|> <|EOV|>]
|
||||
|
||||
ocr special tokens [<|LOC_0|> <|LOC_1|> ... <|LOC_1000|>] 共1001个
|
||||
|
||||
crop special tokens [<|CROP_COL_SEP|>, <|CROP_ROW_SEP|>, <|CROP_IMAGE_SEP|>] 共3个
|
||||
<|CROP_COL_SEP|> for col 维度切 图片width(替换原明文逗号)
|
||||
<|CROP_ROW_SEP|> for row 维度切 图片height(替换原明文回车)
|
||||
<|CROP_IMAGE_SEP|> for 区分原图和crop图 图片width(替换原明文两个回车)
|
||||
|
||||
共2048个 unsed token
|
||||
|
||||
Args:
|
||||
tokenizer (ErnieTokenizer): tokenizer
|
||||
special_token_ids_start (int, optional): special token 起点 ids. Defaults to 254208.
|
||||
special_token_ids_end (int, optional): 词表最多支持大小. Defaults to 256256.
|
||||
"""
|
||||
special_tokens = [
|
||||
special_tokens_info["image_placeholder"],
|
||||
special_tokens_info["audio_placeholder"],
|
||||
]
|
||||
|
||||
if use_ocr_specialtoken:
|
||||
special_tokens.extend(special_tokens_info["ocr_coor"])
|
||||
special_tokens.extend(special_tokens_info["ocr_begin_end"])
|
||||
|
||||
if use_crop_specialtoken:
|
||||
special_tokens.extend(special_tokens_info["crop"])
|
||||
|
||||
# add special_tokens
|
||||
additional_special_tokens = {"additional_special_tokens": special_tokens}
|
||||
tokenizer.add_special_tokens(additional_special_tokens)
|
||||
|
||||
# check
|
||||
first_special_tokens = tokenizer.encode(special_tokens[0])["input_ids"]
|
||||
|
||||
assert first_special_tokens[0] == special_token_ids_start, f"[ERROR] first_special_tokens={first_special_tokens}"
|
||||
assert (
|
||||
len(tokenizer.get_vocab()) < special_token_ids_end
|
||||
), f"[ERROR] vocab_size = {len(tokenizer.get_vocab())} >= {special_token_ids_end} 增加过多special token了!"
|
||||
|
Reference in New Issue
Block a user