[Feature] Support include_stop_str_in_output (#2919)

Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
This commit is contained in:
Jiang-Jia-Jun
2025-07-18 19:43:19 +08:00
committed by GitHub
parent c71d955e9c
commit e421d51001
4 changed files with 80 additions and 15 deletions

View File

@@ -100,7 +100,6 @@ class ErnieProcessor(BaseDataProcessor):
if request.prompt_token_ids is None or len(
request.prompt_token_ids) == 0:
system = request.get("system")
if request.prompt is None and request.messages is None:
raise ValueError(
f"The request should have `input_ids`, `text` or `messages`: {request}.")
@@ -149,7 +148,6 @@ class ErnieProcessor(BaseDataProcessor):
request['stop_token_ids'] = stop_seqs
request['stop_seqs_len'] = stop_seqs_len
system = request.get("system")
# 处理prompt_token_ids
if not request.get('prompt_token_ids'):
if request.get('prompt') is None and request.get(
@@ -249,7 +247,7 @@ class ErnieProcessor(BaseDataProcessor):
token_ids = response_dict["outputs"]["token_ids"]
is_end = response_dict["finished"]
req_id = response_dict["request_id"]
if is_end and len(token_ids) > 0:
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
if token_ids[-1] == self.tokenizer.eos_token_id:
token_ids = token_ids[:-1]
delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
@@ -284,7 +282,7 @@ class ErnieProcessor(BaseDataProcessor):
req_id = response_dict["request_id"]
token_ids = response_dict["outputs"]["token_ids"]
if is_end and len(token_ids) > 0:
if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
if token_ids[-1] == self.tokenizer.eos_token_id:
token_ids = token_ids[:-1]
delta_text, previous_token_ids, previous_texts = self.ids2tokens(