[Feature] Pass through the chat_template_kwargs to the data processing module (#3421)

* fix chat_template_args

* fix args

* add offline

* add offline

* fix

* fix

* fix default enable_thinking value

* fix default enable_thinking value

* modify condition

* Revert "modify condition"

This reverts commit 26430bdeb1.

* fix unit test
This commit is contained in:
luukunn
2025-08-19 10:50:01 +08:00
committed by GitHub
parent a053ab889b
commit 3a7a20d191
6 changed files with 50 additions and 13 deletions

View File

@@ -111,7 +111,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
"""process the input data"""
request.chat_template = kwargs.get("chat_template")
task = request.to_dict()
task["enable_thinking"] = kwargs.get("enable_thinking", True)
task["chat_template_kwargs"] = kwargs.get("chat_template_kwargs")
self.process_request_dict(task, max_model_len)
request = Request.from_dict(task)
request = self._apply_default_parameters(request)
@@ -218,6 +218,15 @@ class ErnieMoEVLProcessor(ErnieProcessor):
elif request.get("messages"):
messages = request["messages"]
self._check_mm_limits(messages)
chat_template_kwargs = request.get("chat_template_kwargs")
if chat_template_kwargs:
if isinstance(chat_template_kwargs, dict):
for k, v in chat_template_kwargs.items():
if k not in request:
request[k] = v
else:
raise ValueError("Invalid input: chat_template_kwargs must be a dict")
request.setdefault("enable_thinking", True)
outputs = self.ernie_processor.request2ids(request)
else:
raise ValueError(f"Request must contain 'prompt', or 'messages': {request}")