mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Bug fix] Fix the multi-input accuracy issue in the pooling model. (#5374)
* fix multi-inputs * fix threshold * fix threshold * fix
This commit is contained in:
@@ -230,13 +230,16 @@ class DataProcessor(BaseDataProcessor):
|
||||
if request.prompt_token_ids is None or len(request.prompt_token_ids) == 0:
|
||||
if request.prompt is not None:
|
||||
prompt = request.prompt
|
||||
add_special_tokens = request.get("add_special_tokens", False)
|
||||
assert isinstance(prompt, str) or (
|
||||
isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
|
||||
), f"prompt must be a string or a list of integers, but got {type(prompt)}"
|
||||
if isinstance(prompt, list): # if prompt is a token id list
|
||||
request.prompt_token_ids = prompt
|
||||
else:
|
||||
request.prompt_token_ids = self.text2ids(request.prompt, max_model_len)
|
||||
request.prompt_token_ids = self.text2ids(
|
||||
request.prompt, max_model_len, add_special_tokens=add_special_tokens
|
||||
)
|
||||
elif request.messages is not None:
|
||||
if self.tokenizer.chat_template is None:
|
||||
raise ValueError("This model does not support chat_template.")
|
||||
@@ -305,13 +308,16 @@ class DataProcessor(BaseDataProcessor):
|
||||
if not request.get("prompt_token_ids"):
|
||||
if request.get("prompt"):
|
||||
prompt = request.get("prompt")
|
||||
add_special_tokens = request.get("add_special_tokens", False)
|
||||
assert isinstance(prompt, str) or (
|
||||
isinstance(prompt, list) and all([isinstance(t, int) for t in prompt])
|
||||
), f"prompt must be a string or a list of integers, but got {type(prompt)}"
|
||||
if isinstance(prompt, list): # if prompt is a token id list
|
||||
request["prompt_token_ids"] = prompt
|
||||
else:
|
||||
request["prompt_token_ids"] = self.text2ids(request["prompt"], max_model_len).tolist()
|
||||
request["prompt_token_ids"] = self.text2ids(
|
||||
request["prompt"], max_model_len, add_special_tokens=add_special_tokens
|
||||
).tolist()
|
||||
elif request.get("messages"):
|
||||
if self.tokenizer.chat_template is None:
|
||||
raise ValueError("This model does not support chat_template.")
|
||||
@@ -503,7 +509,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def text2ids(self, text, max_model_len):
|
||||
def text2ids(self, text, max_model_len, **kwargs):
|
||||
"""
|
||||
text to token ids
|
||||
|
||||
@@ -513,6 +519,8 @@ class DataProcessor(BaseDataProcessor):
|
||||
Returns:
|
||||
List[int]: token ids list
|
||||
"""
|
||||
|
||||
add_special_tokens = kwargs.get("add_special_tokens")
|
||||
if envs.FD_USE_HF_TOKENIZER:
|
||||
tokens = self.tokenizer(
|
||||
text,
|
||||
@@ -529,7 +537,7 @@ class DataProcessor(BaseDataProcessor):
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=max_model_len,
|
||||
add_special_tokens=False,
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
|
||||
return tokens["input_ids"][0]
|
||||
|
||||
Reference in New Issue
Block a user