This commit is contained in:
luukunn
2025-08-15 15:48:23 +08:00
parent 27c9336812
commit 6320f29ccc
3 changed files with 36 additions and 20 deletions

View File

@@ -108,11 +108,14 @@ class ErnieProcessor(BaseDataProcessor):
data_processor_logger.info(f"req_id:{request.request_id}, tokens:{tokens}, token_ids: {token_ids}")
else:
task = request.to_dict()
chat_template_kwargs = request.get("chat_template_kwargs")
chat_template_kwargs = kwargs.get("chat_template_kwargs")
if chat_template_kwargs:
for k, v in chat_template_kwargs.items():
if k not in task:
task[k] = v
if isinstance(chat_template_kwargs, dict):
for k, v in chat_template_kwargs.items():
if k not in task:
task[k] = v
else:
raise ValueError("Invalid input: chat_template_kwargs must be a dict")
request.prompt_token_ids = self.messages2ids(task)
if len(request.prompt_token_ids) == 0:
@@ -146,11 +149,7 @@ class ErnieProcessor(BaseDataProcessor):
request = self._apply_default_parameters(request)
if not request.get("eos_token_ids"):
request["eos_token_ids"] = self.eos_token_ids
chat_template_kwargs = request.get("chat_template_kwargs")
if chat_template_kwargs:
for k, v in chat_template_kwargs.items():
if k not in request:
request[k] = v
# processing stop_sequences
stop_sequences = request.get("stop", [])
if stop_sequences:
@@ -172,6 +171,14 @@ class ErnieProcessor(BaseDataProcessor):
req_id = request.get("request_id", None)
data_processor_logger.info(f"req_id:{req_id}, tokens:{tokens}, token_ids: {token_ids}")
else:
chat_template_kwargs = request.get("chat_template_kwargs")
if chat_template_kwargs:
if isinstance(chat_template_kwargs, dict):
for k, v in chat_template_kwargs.items():
if k not in request:
request[k] = v
else:
raise ValueError("Invalid input: chat_template_kwargs must be a dict")
request["prompt_token_ids"] = self.messages2ids(request)
if len(request["prompt_token_ids"]) == 0:
raise ValueError("Invalid input: prompt_token_ids must be a non-empty sequence of token IDs")

View File

@@ -198,11 +198,6 @@ class ErnieMoEVLProcessor(ErnieProcessor):
request = self._apply_default_parameters(request)
if not request.get("eos_token_ids"):
request["eos_token_ids"] = self.eos_token_ids
chat_template_kwargs = request.get("chat_template_kwargs")
if chat_template_kwargs:
for k, v in chat_template_kwargs.items():
if k not in request:
request[k] = v
stop_sequences = request.get("stop", [])
if stop_sequences:
@@ -222,6 +217,14 @@ class ErnieMoEVLProcessor(ErnieProcessor):
elif request.get("messages"):
messages = request["messages"]
self._check_mm_limits(messages)
chat_template_kwargs = request.get("chat_template_kwargs")
if chat_template_kwargs:
if isinstance(chat_template_kwargs, dict):
for k, v in chat_template_kwargs.items():
if k not in request:
request[k] = v
else:
raise ValueError("Invalid input: chat_template_kwargs must be a dict")
outputs = self.ernie_processor.request2ids(request)
else:
raise ValueError(f"Request must contain 'prompt', or 'messages': {request}")

View File

@@ -222,9 +222,12 @@ class DataProcessor(BaseDataProcessor):
task = request.to_dict()
chat_template_kwargs = kwargs.get("chat_template_kwargs")
if chat_template_kwargs:
for k, v in chat_template_kwargs.items():
if k not in task:
task[k] = v
if isinstance(chat_template_kwargs, dict):
for k, v in chat_template_kwargs.items():
if k not in task:
task[k] = v
else:
raise ValueError("Invalid input: chat_template_kwargs must be a dict")
request.prompt_token_ids = self.messages2ids(task)
else:
raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")
@@ -276,9 +279,12 @@ class DataProcessor(BaseDataProcessor):
raise ValueError("This model does not support chat_template.")
chat_template_kwargs = request.get("chat_template_kwargs")
if chat_template_kwargs:
for k, v in chat_template_kwargs.items():
if k not in request:
request[k] = v
if isinstance(chat_template_kwargs, dict):
for k, v in chat_template_kwargs.items():
if k not in request:
request[k] = v
else:
raise ValueError("Invalid input: chat_template_kwargs must be a dict")
request["prompt_token_ids"] = self.messages2ids(request)
else:
raise ValueError(f"Request must contain 'prompt_token_ids', 'prompt', or 'messages': {request}")