fix chat_template_kwargs

commit 73c97a22fe
parent 5a84324798
Author: luukunn
Date:   2025-08-15 11:18:15 +08:00

5 changed files with 9 additions and 12 deletions


@@ -465,10 +465,7 @@ class LLMEngine:
         request.sampling_params = sampling_params
         request.preprocess_start_time = time.time()
-        enable_thinking = None
-        if kwargs is not None:
-            enable_thinking = kwargs.get("enable_thinking", None)
-        request = self.data_processor.process_request(request, self.cfg.max_model_len, enable_thinking=enable_thinking)
+        request = self.data_processor.process_request(request, self.cfg.max_model_len, **kwargs)
         request.prompt_token_ids_len = len(request.prompt_token_ids)
         request.need_prefill_tokens = request.prompt_token_ids_len
         input_ids_len = request.prompt_token_ids_len
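The engine no longer special-cases enable_thinking: every keyword argument the caller supplies reaches data_processor.process_request unchanged, so new chat-template options need no further engine changes. A minimal caller-side sketch of the new path (`engine` stands for an initialized LLMEngine and is illustrative; add_requests and chat_template_kwargs come from this diff):

    # Hedged sketch: `engine` is an assumed LLMEngine instance.
    # chat_template_kwargs now rides through **kwargs untouched.
    engine.add_requests(
        request,
        sampling_params,
        chat_template_kwargs={"enable_thinking": False},
    )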


@@ -248,7 +248,7 @@ class LLM:
         self,
         prompts,
         sampling_params,
-        chat_template_kwargs: Optional[dict[str, Any]] = None,
+        **kwargs,
     ):
         """
         Add a request to the LLM Engine and return its request ID.
@@ -289,10 +289,7 @@ class LLM:
                 current_sampling_params = sampling_params[i]
             else:
                 current_sampling_params = sampling_params
-            enable_thinking = None
-            if chat_template_kwargs is not None:
-                enable_thinking = chat_template_kwargs.get("enable_thinking", None)
-            self.llm_engine.add_requests(tasks, current_sampling_params, enable_thinking=enable_thinking)
+            self.llm_engine.add_requests(tasks, current_sampling_params, **kwargs)
         return req_ids

     def _decode_token(self, token_id: int) -> str:
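At the user-facing layer the explicit chat_template_kwargs parameter is gone; the method now accepts **kwargs and forwards them to llm_engine.add_requests, which forwards them again to the processor. Assuming the enclosing method is the add-request method documented above (its def line sits just before the first hunk and is not shown in this view), a call would look like this sketch:

    # Hedged sketch: prompts/sampling_params construction is illustrative.
    req_ids = llm.add_request(
        prompts,
        sampling_params,
        chat_template_kwargs={"enable_thinking": True},  # forwarded end to end
    )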


@@ -90,6 +90,7 @@ class ErnieProcessor(BaseDataProcessor):
         request = self._apply_default_parameters(request)
         if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
             request.eos_token_ids = self.eos_token_ids
+        request.enable_thinking = kwargs.get("chat_template_kwargs", {}).get("enable_thinking")
         stop_sequences = request.get("stop", [])
         if stop_sequences is not None and len(stop_sequences) != 0:
             stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
@@ -140,7 +141,7 @@ class ErnieProcessor(BaseDataProcessor):
         request = self._apply_default_parameters(request)
         if not request.get("eos_token_ids"):
             request["eos_token_ids"] = self.eos_token_ids
+        request["enable_thinking"] = request.get("chat_template_kwargs", {}).get("enable_thinking")
         # processing stop_sequences
         stop_sequences = request.get("stop", [])
         if stop_sequences:
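The processors now read the flag with a None-safe nested lookup: .get("chat_template_kwargs", {}) falls back to an empty dict when the key is absent, so the chained .get("enable_thinking") yields None rather than raising. One caveat: the default only applies when the key is missing; a caller that explicitly passes chat_template_kwargs=None would make the chained .get raise AttributeError. A standalone illustration of the lookup's behavior:

    # Pure-Python behavior of the nested lookup used in this diff.
    kwargs = {}
    print(kwargs.get("chat_template_kwargs", {}).get("enable_thinking"))  # None
    kwargs = {"chat_template_kwargs": {"enable_thinking": True}}
    print(kwargs.get("chat_template_kwargs", {}).get("enable_thinking"))  # True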


@@ -110,7 +110,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
     def process_request(self, request, max_model_len=None, **kwargs):
         """process the input data"""
         task = request.to_dict()
-        task["enable_thinking"] = kwargs.get("enable_thinking", True)
+        task["chat_template_kwargs"] = kwargs.get("chat_template_kwargs", {})
         self.process_request_dict(task, max_model_len)
         request = Request.from_dict(task)
         request = self._apply_default_parameters(request)
@@ -198,6 +198,7 @@ class ErnieMoEVLProcessor(ErnieProcessor):
         request = self._apply_default_parameters(request)
         if not request.get("eos_token_ids"):
             request["eos_token_ids"] = self.eos_token_ids
+        request["enable_thinking"] = request.get("chat_template_kwargs", {}).get("enable_thinking")
         stop_sequences = request.get("stop", [])
         if stop_sequences:
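Unlike the text-only path, the multimodal processor stores the whole chat_template_kwargs dict on the task before calling process_request_dict, so any template option, not just enable_thinking, stays available downstream. A hedged sketch of how such a dict is typically consumed; apply_chat_template follows the Hugging Face tokenizer API and is an assumption about the downstream rendering step, not code from this diff:

    # Hedged sketch: `tokenizer` is assumed to follow the Hugging Face
    # tokenizer API; task/messages values are illustrative.
    task = {"chat_template_kwargs": {"enable_thinking": False}}
    messages = [{"role": "user", "content": "hello"}]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        **task.get("chat_template_kwargs", {}),  # forwarded template options
    )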


@@ -207,7 +207,7 @@ class DataProcessor(BaseDataProcessor):
         request = self._apply_default_parameters(request)
         if request.get("eos_token_ids") is None or len(request.eos_token_ids) == 0:
             request.eos_token_ids = self.eos_token_ids
+        request.enable_thinking = kwargs.get("chat_template_kwargs", {}).get("enable_thinking")
         stop_sequences = request.get("stop", [])
         if stop_sequences is not None and len(stop_sequences) != 0:
             stop_seqs, stop_seqs_len = self.update_stop_seq(stop_sequences)
@@ -254,6 +254,7 @@ class DataProcessor(BaseDataProcessor):
         request = self._apply_default_parameters(request)
         if not request.get("eos_token_ids"):
             request["eos_token_ids"] = self.eos_token_ids
+        request["enable_thinking"] = request.get("chat_template_kwargs", {}).get("enable_thinking")
         # processing stop_sequences
         stop_sequences = request.get("stop", [])