[fix] non-streaming api now returns full output ids if return_token_ids is enabled (#2951)

This commit is contained in:
李泳桦
2025-07-22 14:35:56 +08:00
committed by GitHub
parent 2c6a9e887e
commit 2a8a2c06de
4 changed files with 31 additions and 20 deletions

View File

@@ -138,14 +138,15 @@ class ErnieProcessor(BaseDataProcessor):
request = self._apply_default_parameters(request)
if not request.get("eos_token_ids"):
request["eos_token_ids"] = self.eos_token_ids
# 处理stop_sequences
# process stop_sequences
stop_sequences = request.get("stop", [])
if stop_sequences:
stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
request["stop_token_ids"] = stop_seqs
request["stop_seqs_len"] = stop_seqs_len
# 处理prompt_token_ids
# process prompt_token_ids
if not request.get("prompt_token_ids"):
if request.get("prompt") is None and request.get("messages") is None:
raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")
@@ -161,7 +162,7 @@ class ErnieProcessor(BaseDataProcessor):
else:
request["prompt_token_ids"] = self.messages2ids(request)
# 截断超过长度限制的prompt
# truncate prompts that exceed the length limit
if max_model_len is not None and len(request["prompt_token_ids"]) > max_model_len:
request["prompt_token_ids"] = request["prompt_token_ids"][: max_model_len - 1]
if request.get("max_tokens") is None: