Files
FastDeploy/fastdeploy/input/ernie4_5_tokenizer.py
2025-08-27 20:36:02 +08:00

311 lines
11 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
import re
from shutil import copyfile
from typing import Dict, List, Optional, Tuple
import numpy as np
import paddle
import sentencepiece as spm
from paddleformers.transformers import PretrainedTokenizer
from paddleformers.transformers.tokenizer_utils_base import PaddingStrategy, TextInput
from paddleformers.utils.log import logger
class Ernie4_5Tokenizer(PretrainedTokenizer):
    """
    A more convenient `ErnieBotTokenizer`.

    Encodes the special tokens currently used in the SFT/PPO stages and also
    supports multimodal inputs. Tokenization is delegated to a SentencePiece
    model loaded from ``vocab_file``.
    """

    resource_files_names = {
        "vocab_file": "tokenizer.model",
    }
    pretrained_resource_files_map = {"vocab_file": {"ernie-bot-10b": None}}
    pretrained_init_configuration = {
        "ernie-bot-10b": {},
    }
    model_input_names = [
        "input_ids",
        "position_ids",
        "attention_mask",
        "labels",
    ]
    padding_side = "right"

    def __init__(
        self,
        vocab_file,
        bos_token="<s>",
        cls_token="<cls>",
        eos_token="</s>",
        mask_token="<mask:0>",
        pad_token="<pad>",
        sep_token="<sep>",
        unk_token="<unk>",
        additional_special_tokens=None,
        verbose=False,
        **kwargs,
    ):
        """
        Build the tokenizer and load the SentencePiece model.

        Args:
            vocab_file: Path to the SentencePiece model file.
            bos_token, cls_token, eos_token, mask_token, pad_token, sep_token,
            unk_token: Special-token strings forwarded to the base class.
            additional_special_tokens: Extra special tokens; defaults to
                ``["<mask:1>", "<mask:7>"]`` when ``None``.
            verbose: Accepted for interface compatibility.
            **kwargs: Forwarded unchanged to the base-class constructor.
        """
        if additional_special_tokens is None:
            additional_special_tokens = ["<mask:1>", "<mask:7>"]
        # NOTE(review): `verbose` is accepted but `False` is always forwarded
        # to the base class -- confirm this is intentional before changing it.
        super().__init__(
            bos_token=bos_token,
            cls_token=cls_token,
            eos_token=eos_token,
            mask_token=mask_token,
            pad_token=pad_token,
            sep_token=sep_token,
            unk_token=unk_token,
            additional_special_tokens=additional_special_tokens,
            verbose=False,
            **kwargs,
        )
        self.vocab_file = vocab_file
        self.sp_model = spm.SentencePieceProcessor()
        self.sp_model.Load(vocab_file)

    @property
    def space_token(self):
        """Return the literal token string ``"<mask:1>"``."""
        return "<mask:1>"

    @property
    def space_token_id(self):
        """Return the SentencePiece id of ``"<mask:1>"``."""
        return self.sp_model.piece_to_id("<mask:1>")

    @property
    def gend_token(self):
        """Return the literal token string ``"<mask:7>"``."""
        return "<mask:7>"

    @property
    def gend_token_id(self):
        """Return the SentencePiece id of ``"<mask:7>"``."""
        return self.sp_model.piece_to_id("<mask:7>")

    @property
    def im_start_id(self):
        """Return the SentencePiece id of ``"<|im_start|>"``."""
        return self.sp_model.piece_to_id("<|im_start|>")

    @property
    def im_end_id(self):
        """Return the SentencePiece id of ``"<|im_end|>"``."""
        return self.sp_model.piece_to_id("<|im_end|>")

    @property
    def vocab_size(self):
        """Return the vocabulary size reported by the SentencePiece model."""
        return self.sp_model.vocab_size()

    def get_vocab(self):
        """
        Return a token -> id mapping.

        The mapping covers every id in ``range(self.vocab_size)`` and is then
        updated with ``self.added_tokens_encoder``.
        """
        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
        vocab.update(self.added_tokens_encoder)
        return vocab

    def _tokenize(self, text):
        """Split ``text`` into SentencePiece pieces."""
        return self.sp_model.encode_as_pieces(text)

    def _convert_token_to_id(self, token):
        """Map a piece string to its SentencePiece id."""
        return self.sp_model.piece_to_id(token)

    def _convert_id_to_token(self, id):
        """Map a SentencePiece id to its piece string."""
        return self.sp_model.id_to_piece(id)

    def spec_init(self):
        """
        Cache ``self.all_special_tokens`` as the set ``self.all_spec_tok``.

        The set is built only on the first call, so that decoding can test
        membership against a set instead of rebuilding the token list.
        NOTE(review): the cache is not refreshed afterwards -- confirm that
        special tokens are never added after the first call.
        """
        if not hasattr(self, "all_spec_tok"):
            self.all_spec_tok = set(self.all_special_tokens)

    def convert_tokens_to_string(self, tokens):
        """
        Convert a sequence of token strings into a single string.

        Special tokens are emitted verbatim; each run of ordinary tokens
        between them is decoded by the SentencePiece model. No separator is
        inserted around special tokens and the result is not stripped.
        """
        self.spec_init()
        pieces = []
        pending = []
        for token in tokens:
            if token in self.all_spec_tok:
                # Flush the ordinary tokens collected so far, then append the
                # special token as-is so SentencePiece never decodes it.
                pieces.append(self.sp_model.decode(pending))
                pieces.append(token)
                pending = []
            else:
                pending.append(token)
        pieces.append(self.sp_model.decode(pending))
        return "".join(pieces)

    def prepare_for_model(self, *args, **kwargs):
        """
        Delegate to the base class after discarding ``add_special_tokens``.

        The ``add_special_tokens`` keyword is silently dropped if present; all
        other arguments are forwarded unchanged.
        """
        kwargs.pop("add_special_tokens", None)
        return super().prepare_for_model(*args, **kwargs)

    def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Optional[Tuple[str]]:
        """
        Save the SentencePiece model file to a directory.

        Args:
            save_directory (`str`):
                The directory in which to save the vocabulary.
            filename_prefix (`str`, *optional*):
                When given, ``"<prefix>-"`` is prepended to the file name.

        Returns:
            `Tuple(str)`: Paths to the files saved, or ``None`` (after logging
            an error) when ``save_directory`` is not an existing directory.
        """
        if not os.path.isdir(save_directory):
            logger.error(f"Vocabulary path ({save_directory}) should be a directory")
            return
        out_vocab_file = os.path.join(
            save_directory,
            (filename_prefix + "-" if filename_prefix else "") + self.resource_files_names["vocab_file"],
        )
        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
            copyfile(self.vocab_file, out_vocab_file)
        elif not os.path.isfile(self.vocab_file):
            # The original file is gone: re-serialize the in-memory model.
            with open(out_vocab_file, "wb") as fi:
                content_spiece_model = self.sp_model.serialized_model_proto()
                fi.write(content_spiece_model)
        return (out_vocab_file,)

    def tokenize(self, text: TextInput, **kwargs) -> List[str]:
        """
        Convert a string into a sequence of tokens.

        The text is first split on the tokens known to ``self.tokens_trie``;
        fragments listed in ``self.unique_no_split_tokens`` are kept whole and
        every other fragment is passed to ``_tokenize``.

        Args:
            text (`str`):
                The sequence to be encoded.
            **kwargs (additional keyword arguments):
                Passed along to the model-specific `prepare_for_tokenization`
                preprocessing method.

        Returns:
            `List[str]`: The list of tokens.
        """
        self.spec_init()
        text, kwargs = self.prepare_for_tokenization(text, **kwargs)
        if hasattr(self, "do_lower_case") and self.do_lower_case:
            # Lowercase everything except special / no-split tokens.
            # `all_special_tokens` is used here rather than the cached
            # `all_spec_tok`: the latter is a set, cannot be concatenated to a
            # list, and has no stable order for the regex alternation.
            protected_tokens = list(self.unique_no_split_tokens) + list(self.all_special_tokens)
            escaped_special_toks = [re.escape(s_tok) for s_tok in protected_tokens]
            pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
            text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)

        no_split_token = set(self.unique_no_split_tokens)
        tokenized_text = []
        for token in self.tokens_trie.split(text):
            # Skip empty fragments produced by the split.
            if not token:
                continue
            if token in no_split_token:
                tokenized_text.append(token)
            else:
                tokenized_text.extend(self._tokenize(token))
        # e.g. ["This", " is", " something", "<special_token_1>", "else"]
        return tokenized_text

    def _decode(self, *args, **kwargs):
        """
        Decode via the base class with space post-processing forced off.

        Any caller-supplied ``clean_up_tokenization_spaces`` or
        ``spaces_between_special_tokens`` is discarded and ``False`` is passed
        for both.
        """
        kwargs.pop("clean_up_tokenization_spaces", None)
        kwargs.pop("spaces_between_special_tokens", None)
        return super()._decode(
            *args,
            **kwargs,
            clean_up_tokenization_spaces=False,
            spaces_between_special_tokens=False,
        )

    def _pad(
        self,
        encoded_inputs: Dict,
        max_length: Optional[int] = None,
        padding_strategy=PaddingStrategy.DO_NOT_PAD,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
    ) -> dict:
        """
        Pad ``encoded_inputs``, handling the attention mask in this class.

        The attention mask is removed from ``encoded_inputs`` (or created),
        padded here with zeros, and re-attached as a nested list after the
        base-class ``_pad`` has padded the remaining fields with
        ``return_attention_mask=False``.

        Args:
            encoded_inputs: Dictionary of tokenized inputs; the entry named by
                ``self.model_input_names[0]`` determines the sequence length.
            max_length: Target length; overridden by the input length when
                ``padding_strategy`` is ``PaddingStrategy.LONGEST``.
            padding_strategy: A ``PaddingStrategy`` value.
            pad_to_multiple_of: When set, ``max_length`` is rounded up to the
                next multiple of this value.
            return_attention_mask: Whether to return an attention mask;
                defaults to whether ``"attention_mask"`` is listed in
                ``self.model_input_names``.

        Returns:
            The padded ``encoded_inputs`` dictionary.

        Raises:
            ValueError: If a supplied attention mask is not a
                ``paddle.Tensor``, ``list`` or ``numpy.ndarray``, or if
                ``self.padding_side`` is neither ``"right"`` nor ``"left"``.
        """
        if return_attention_mask is None:
            return_attention_mask = "attention_mask" in self.model_input_names

        if return_attention_mask:
            required_input = encoded_inputs[self.model_input_names[0]]
            if padding_strategy == PaddingStrategy.LONGEST:
                max_length = len(required_input)
            if max_length is not None and pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
                max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
            needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length

            if "attention_mask" in encoded_inputs and encoded_inputs["attention_mask"] is not None:
                # Normalize a caller-supplied mask to a numpy array.
                attention_mask = encoded_inputs.pop("attention_mask")
                if isinstance(attention_mask, paddle.Tensor):
                    attention_mask = attention_mask.numpy()
                elif isinstance(attention_mask, list):
                    attention_mask = np.array(attention_mask)
                elif not isinstance(attention_mask, np.ndarray):
                    raise ValueError(f"Unexpected type {type(attention_mask)} of attention_mask, ")
            else:
                # No mask supplied: build a lower-triangular mask of shape
                # (1, seq_len, seq_len).
                attention_mask = np.tril(
                    np.ones(
                        (len(required_input), len(required_input)),
                        dtype=np.int64,
                    )
                )
                attention_mask = np.expand_dims(attention_mask, axis=0)

            if needs_to_be_padded:
                difference = max_length - len(required_input)
                if self.padding_side == "right":
                    if attention_mask.ndim == 1:
                        pad_width = [(0, difference)]
                    else:
                        # Leave the leading axis alone; pad the two sequence axes.
                        pad_width = [(0, 0), (0, difference), (0, difference)]
                elif self.padding_side == "left":
                    if attention_mask.ndim == 1:
                        pad_width = [(difference, 0)]
                    else:
                        pad_width = [(0, 0), (difference, 0), (difference, 0)]
                else:
                    raise ValueError("Invalid padding strategy:" + str(self.padding_side))
                attention_mask = np.pad(
                    attention_mask,
                    pad_width=pad_width,
                    mode="constant",
                    constant_values=0,
                )

        # The base class pads every other field; the mask is handled above.
        encoded_inputs = super()._pad(
            encoded_inputs,
            max_length,
            padding_strategy=padding_strategy,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=False,
        )
        if return_attention_mask:
            encoded_inputs["attention_mask"] = attention_mask.tolist()
        return encoded_inputs