[Bug fix] Fix the multi-input accuracy issue in the pooling model. (#5374)

* fix multi-input handling

* fix threshold

* fix threshold

* fix
This commit is contained in:
lizexu123
2025-12-05 20:18:17 +08:00
committed by GitHub
parent 96d2d4877b
commit d4979347ca
3 changed files with 107 additions and 6 deletions

View File

@@ -230,13 +230,16 @@ class DataProcessor(BaseDataProcessor):
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
if request.prompt is not None:
prompt = request.prompt
add_special_tokens = request.get("add_special_tokens", False)
assert isinstance(prompt, str) or (
isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
), f"prompt must be a string or a list of integers, but got {type(prompt)}"
if isinstance(prompt, list): # if prompt is a token id list
request.prompt_token_ids = prompt
else:
request.prompt_token_ids = self.text2ids(request.prompt, max_model_len)
request.prompt_token_ids = self.text2ids(
request.prompt, max_model_len, add_special_tokens=add_special_tokens
)
elif request.messages is not None:
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat_template.")
@@ -305,13 +308,16 @@ class DataProcessor(BaseDataProcessor):
if not request.get("prompt_token_ids"):
if request.get("prompt"):
prompt = request.get("prompt")
add_special_tokens = request.get("add_special_tokens", False)
assert isinstance(prompt, str) or (
isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
), f"prompt must be a string or a list of integers, but got {type(prompt)}"
if isinstance(prompt, list): # if prompt is a token id list
request["prompt_token_ids"] = prompt
else:
request["prompt_token_ids"] = self.text2ids(request["prompt"], max_model_len).tolist()
request["prompt_token_ids"] = self.text2ids(
request["prompt"], max_model_len, add_special_tokens=add_special_tokens
).tolist()
elif request.get("messages"):
if self.tokenizer.chat_template is None:
raise ValueError("This model does not support chat_template.")
@@ -503,7 +509,7 @@ class DataProcessor(BaseDataProcessor):
**kwargs,
)
def text2ids(self, text, max_model_len):
def text2ids(self, text, max_model_len, **kwargs):
"""
text to token ids
@@ -513,6 +519,8 @@ class DataProcessor(BaseDataProcessor):
Returns:
List[int]: token ids list
"""
add_special_tokens = kwargs.get("add_special_tokens")
if envs.FD_USE_HF_TOKENIZER:
tokens = self.tokenizer(
text,
@@ -529,7 +537,7 @@ class DataProcessor(BaseDataProcessor):
padding=True,
truncation=True,
max_length=max_model_len,
add_special_tokens=False,
add_special_tokens=add_special_tokens,
)
return tokens["input_ids"][0]