[Feature]support chat_template.jinja (#3721)

* add support chat_template.jinja

* add support chat_template.jinja
This commit is contained in:
luukunn
2025-08-30 17:05:34 +08:00
committed by GitHub
parent b21e085f3e
commit 9a7c231f2c
3 changed files with 9 additions and 2 deletions

View File

@@ -14,6 +14,7 @@
# limitations under the License.
"""
import os
import uuid
from copy import deepcopy
from pathlib import Path
@@ -162,9 +163,15 @@ def parse_chat_messages(messages):
def load_chat_template(
chat_template: Union[Path, str],
model_path: Path = None,
is_literal: bool = False,
) -> Optional[str]:
if chat_template is None:
if model_path:
chat_template_file = os.path.join(model_path, "chat_template.jinja")
if os.path.exists(chat_template_file):
with open(chat_template_file) as f:
return f.read()
return None
if is_literal:
if isinstance(chat_template, Path):

View File

@@ -102,7 +102,7 @@ class LLM:
self.master_node_ip = self.llm_engine.cfg.master_ip
self._receive_output_thread = threading.Thread(target=self._receive_output, daemon=True)
self._receive_output_thread.start()
self.chat_template = load_chat_template(chat_template)
self.chat_template = load_chat_template(chat_template, model)
def _check_master(self):
"""

View File

@@ -80,7 +80,7 @@ parser.add_argument("--max-concurrency", default=512, type=int, help="max concur
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
args.model = retrive_model_from_server(args.model, args.revision)
chat_template = load_chat_template(args.chat_template)
chat_template = load_chat_template(args.chat_template, args.model)
if args.tool_parser_plugin:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)
llm_engine = None