mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 16:48:03 +08:00
[Feature] Support include_stop_str_in_output (#2919)
Co-authored-by: Jiang-Jia-Jun <jiangjiajun@baidu.com>
This commit is contained in:
@@ -100,7 +100,6 @@ class ErnieProcessor(BaseDataProcessor):
         if request.prompt_token_ids is None or len(
                 request.prompt_token_ids) == 0:
             system = request.get("system")
             if request.prompt is None and request.messages is None:
                 raise ValueError(
                     f"The request should have `input_ids`, `text` or `messages`: {request}.")
@@ -149,7 +148,6 @@ class ErnieProcessor(BaseDataProcessor):
             request['stop_token_ids'] = stop_seqs
             request['stop_seqs_len'] = stop_seqs_len

         system = request.get("system")
         # 处理prompt_token_ids
         if not request.get('prompt_token_ids'):
             if request.get('prompt') is None and request.get(
@@ -249,7 +247,7 @@ class ErnieProcessor(BaseDataProcessor):
         token_ids = response_dict["outputs"]["token_ids"]
         is_end = response_dict["finished"]
         req_id = response_dict["request_id"]
-        if is_end and len(token_ids) > 0:
+        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
             if token_ids[-1] == self.tokenizer.eos_token_id:
                 token_ids = token_ids[:-1]
         delta_text, _, previous_texts = self.ids2tokens(token_ids, req_id)
@@ -284,7 +282,7 @@ class ErnieProcessor(BaseDataProcessor):
         req_id = response_dict["request_id"]
         token_ids = response_dict["outputs"]["token_ids"]

-        if is_end and len(token_ids) > 0:
+        if is_end and len(token_ids) > 0 and not kwargs.get("include_stop_str_in_output"):
             if token_ids[-1] == self.tokenizer.eos_token_id:
                 token_ids = token_ids[:-1]
         delta_text, previous_token_ids, previous_texts = self.ids2tokens(
|
Reference in New Issue
Block a user