diff --git a/fastdeploy/entrypoints/chat_utils.py b/fastdeploy/entrypoints/chat_utils.py
index c90df1529..eb25122e1 100644
--- a/fastdeploy/entrypoints/chat_utils.py
+++ b/fastdeploy/entrypoints/chat_utils.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 """
 
+import os
 import uuid
 from copy import deepcopy
 from pathlib import Path
@@ -162,9 +163,15 @@ def parse_chat_messages(messages):
 
 def load_chat_template(
     chat_template: Union[Path, str],
+    model_path: Path = None,
     is_literal: bool = False,
 ) -> Optional[str]:
     if chat_template is None:
+        if model_path:
+            chat_template_file = os.path.join(model_path, "chat_template.jinja")
+            if os.path.exists(chat_template_file):
+                with open(chat_template_file) as f:
+                    return f.read()
         return None
     if is_literal:
         if isinstance(chat_template, Path):
diff --git a/fastdeploy/entrypoints/llm.py b/fastdeploy/entrypoints/llm.py
index 968306e77..8d88ea3d7 100644
--- a/fastdeploy/entrypoints/llm.py
+++ b/fastdeploy/entrypoints/llm.py
@@ -102,7 +102,7 @@ class LLM:
         self.master_node_ip = self.llm_engine.cfg.master_ip
         self._receive_output_thread = threading.Thread(target=self._receive_output, daemon=True)
         self._receive_output_thread.start()
-        self.chat_template = load_chat_template(chat_template)
+        self.chat_template = load_chat_template(chat_template, model)
 
     def _check_master(self):
         """
diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py
index aceccd837..ede64ad11 100644
--- a/fastdeploy/entrypoints/openai/api_server.py
+++ b/fastdeploy/entrypoints/openai/api_server.py
@@ -80,7 +80,7 @@ parser.add_argument("--max-concurrency", default=512, type=int, help="max concur
 parser = EngineArgs.add_cli_args(parser)
 args = parser.parse_args()
 args.model = retrive_model_from_server(args.model, args.revision)
-chat_template = load_chat_template(args.chat_template)
+chat_template = load_chat_template(args.chat_template, args.model)
 if args.tool_parser_plugin:
     ToolParserManager.import_tool_parser(args.tool_parser_plugin)
 llm_engine = None