mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-04 08:16:42 +08:00
@@ -165,7 +165,14 @@ class DataProcessor(BaseDataProcessor):
|
||||
|
||||
self.model_name_or_path = model_name_or_path
|
||||
|
||||
self._init_config()
|
||||
# Generation config
|
||||
try:
|
||||
self.generation_config = GenerationConfig.from_pretrained(self.model_name_or_path)
|
||||
except Exception as e:
|
||||
data_processor_logger.warning(
|
||||
f"Can't find generation config: {e}, so it will not use generation_config field in the model config"
|
||||
)
|
||||
self.generation_config = None
|
||||
|
||||
self.decode_status = dict()
|
||||
self.tokenizer = self._load_tokenizer()
|
||||
@@ -184,30 +191,6 @@ class DataProcessor(BaseDataProcessor):
|
||||
self.reasoning_parser = reasoning_parser_obj(self.tokenizer)
|
||||
self.tokenizer.pad_token_id = self.pad_token_id
|
||||
|
||||
def _init_config(self):
|
||||
"""
|
||||
初始化配置,包括模型名称、使用Hugging Face Tokenizer等。
|
||||
|
||||
Args:
|
||||
无参数,但是会从环境变量中获取一些配置信息。
|
||||
|
||||
Returns:
|
||||
无返回值,直接修改了类的属性。
|
||||
|
||||
Raises:
|
||||
无异常抛出。
|
||||
"""
|
||||
self.use_hf_tokenizer = int(envs.FD_USE_HF_TOKENIZER) == 1
|
||||
|
||||
# Generation config
|
||||
try:
|
||||
self.generation_config = GenerationConfig.from_pretrained(self.model_name_or_path)
|
||||
except Exception as e:
|
||||
data_processor_logger.warning(
|
||||
f"Can't find generation config: {e}, so it will not use generation_config field in the model config"
|
||||
)
|
||||
self.generation_config = None
|
||||
|
||||
def process_request(self, request, max_model_len=None, **kwargs):
|
||||
"""
|
||||
Preprocess the request
|
||||
@@ -433,7 +416,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
Returns:
|
||||
List[int]: token ids list
|
||||
"""
|
||||
if self.use_hf_tokenizer:
|
||||
if envs.FD_USE_HF_TOKENIZER:
|
||||
tokens = self.tokenizer(
|
||||
text,
|
||||
return_tensors="np",
|
||||
@@ -491,7 +474,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
Returns:
|
||||
List[str]: strings
|
||||
"""
|
||||
if self.use_hf_tokenizer:
|
||||
if envs.FD_USE_HF_TOKENIZER:
|
||||
if task_id not in self.decode_status:
|
||||
# history token ids & history token strings & befer decode str
|
||||
self.decode_status[task_id] = [[], [], ""]
|
||||
@@ -536,7 +519,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
Returns:
|
||||
tokenizer (AutoTokenizer)
|
||||
"""
|
||||
if self.use_hf_tokenizer:
|
||||
if envs.FD_USE_HF_TOKENIZER:
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
return AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=False)
|
||||
@@ -557,7 +540,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
"""
|
||||
results_all = ""
|
||||
if task_id in self.decode_status:
|
||||
if self.use_hf_tokenizer:
|
||||
if envs.FD_USE_HF_TOKENIZER:
|
||||
results_all = self.decode_status[task_id][2]
|
||||
else:
|
||||
results_all = "".join(self.decode_status[task_id][3])
|
||||
|
Reference in New Issue
Block a user