[fix] non-streaming api now returns full output ids if return_token_ids is enabled (#2951)
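Only the ErnieProcessor hunks of this commit are shown below. For orientation, here is a minimal sketch of the contract the commit title describes: when a client sets return_token_ids, the non-streaming response should carry the complete sequence of generated token ids, not just the final chunk's. The names used here (chunks, completion_token_ids) are illustrative assumptions, not FastDeploy's actual response schema:

# Illustrative sketch of the fixed contract, not FastDeploy's serving code.
# All field names (chunks, completion_token_ids) are assumptions.
def build_final_response(chunks: list, return_token_ids: bool) -> dict:
    """Assemble the non-streaming response for one finished request."""
    response = {"text": "".join(c["text"] for c in chunks)}
    if return_token_ids:
        # The fix: return the full output ids accumulated across all chunks,
        # rather than only the ids from the last chunk.
        response["completion_token_ids"] = [t for c in chunks for t in c["token_ids"]]
    return response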
@@ -138,14 +138,15 @@ class ErnieProcessor(BaseDataProcessor):
         request = self._apply_default_parameters(request)
         if not request.get("eos_token_ids"):
             request["eos_token_ids"] = self.eos_token_ids
-        # 处理stop_sequences
+
+        # processing stop_sequences
         stop_sequences = request.get("stop", [])
         if stop_sequences:
             stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
             request["stop_token_ids"] = stop_seqs
             request["stop_seqs_len"] = stop_seqs_len
 
-        # 处理prompt_token_ids
+        # processing prompt_token_ids
         if not request.get("prompt_token_ids"):
             if request.get("prompt") is None and request.get("messages") is None:
                 raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
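In the stop-sequence branch above, self.update_stop_seq tokenizes each stop string and records its length so the runtime can match stop sequences on token ids rather than on text. A self-contained sketch of what such a helper might do, assuming a HuggingFace-style tokenizer with encode(); this is not FastDeploy's actual implementation:

# Hypothetical stand-in for ErnieProcessor.update_stop_seq, assuming a
# HuggingFace-style tokenizer; padding/packing details are omitted.
def update_stop_seq(tokenizer, stop_sequences):
    if isinstance(stop_sequences, str):
        stop_sequences = [stop_sequences]
    # Tokenize each stop string and record its length so the runtime can
    # compare the tail of the generated ids against each stop sequence.
    stop_seqs = [tokenizer.encode(s, add_special_tokens=False) for s in stop_sequences]
    stop_seqs_len = [len(ids) for ids in stop_seqs]
    return stop_seqs, stop_seqs_len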
@@ -161,7 +162,7 @@ class ErnieProcessor(BaseDataProcessor):
         else:
             request["prompt_token_ids"] = self.messages2ids(request)
 
-        # 截断超过长度限制的prompt
+        # truncate prompts that exceed the length limit
         if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
             request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
         if request.get("max_tokens") is None:
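The second hunk truncates an over-long prompt to max_model_len - 1 ids; keeping one slot free plausibly leaves room for at least one generated token, though the diff does not state the rationale. A condensed sketch of the resolve-then-truncate flow across both hunks, where the prompt-encoding branch (elided between the hunks) and the helper bodies are assumptions, not FastDeploy's code:

# Condensed sketch of the flow in the hunks above; helper names mirror the
# diff, but the prompt-encoding branch is assumed, not shown in the commit.
def resolve_prompt_token_ids(request, tokenizer, messages2ids, max_model_len=None):
    if not request.get("prompt_token_ids"):
        if request.get("prompt") is None and request.get("messages") is None:
            raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
        if request.get("prompt") is not None:
            request["prompt_token_ids"] = tokenizer.encode(request["prompt"])
        else:
            request["prompt_token_ids"] = messages2ids(request)  # chat-template path
    # Truncate prompts that exceed the length limit, keeping max_model_len - 1
    # ids so generation still has room for at least one token.
    if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
        request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
    return request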