Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Feature] Support chat_template.jinja (#3721)
* add support for chat_template.jinja
@@ -14,6 +14,7 @@
 # limitations under the License.
 """
 
+import os
 import uuid
 from copy import deepcopy
 from pathlib import Path
@@ -162,9 +163,15 @@ def parse_chat_messages(messages):
 
 def load_chat_template(
     chat_template: Union[Path, str],
+    model_path: Path = None,
     is_literal: bool = False,
 ) -> Optional[str]:
     if chat_template is None:
+        if model_path:
+            chat_template_file = os.path.join(model_path, "chat_template.jinja")
+            if os.path.exists(chat_template_file):
+                with open(chat_template_file) as f:
+                    return f.read()
         return None
     if is_literal:
         if isinstance(chat_template, Path):
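Taken together, the hunk above makes load_chat_template fall back to a chat_template.jinja file sitting in the model directory when no template is supplied, while an explicit template still goes through the pre-existing literal/path handling. A minimal, self-contained sketch of that fallback behavior, using an illustrative helper name and a throwaway directory (neither is part of the patch):

import os
import tempfile

def resolve_chat_template(chat_template, model_path):
    # Mirrors the new fallback: with no explicit template, look for
    # chat_template.jinja inside the model directory and return its text.
    if chat_template is None:
        if model_path:
            candidate = os.path.join(model_path, "chat_template.jinja")
            if os.path.exists(candidate):
                with open(candidate) as f:
                    return f.read()
        return None
    # An explicit template is returned as-is here; the real function keeps
    # its existing literal/path handling for this case.
    return chat_template

# Usage: a temporary "model directory" containing a Jinja template.
with tempfile.TemporaryDirectory() as model_dir:
    with open(os.path.join(model_dir, "chat_template.jinja"), "w") as f:
        f.write("{% for m in messages %}{{ m.role }}: {{ m.content }}\n{% endfor %}")
    print(resolve_chat_template(None, model_dir))  # template text from the file
    print(resolve_chat_template(None, None))       # no model path -> None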
@@ -102,7 +102,7 @@ class LLM:
         self.master_node_ip = self.llm_engine.cfg.master_ip
         self._receive_output_thread = threading.Thread(target=self._receive_output, daemon=True)
         self._receive_output_thread.start()
-        self.chat_template = load_chat_template(chat_template)
+        self.chat_template = load_chat_template(chat_template, model)
 
     def _check_master(self):
         """
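With the constructor now forwarding the model path, a chat_template.jinja bundled with the model checkpoint is picked up by the offline LLM entrypoint without any extra argument. A rough usage sketch, assuming FastDeploy's vLLM-style offline API; the import, model path, and chat() call shape are assumptions for illustration, not taken from this diff:

from fastdeploy import LLM  # assumed top-level import

# If ./my_model ships a chat_template.jinja next to its weights, it is loaded
# implicitly; passing chat_template="..." to the constructor would override it.
llm = LLM(model="./my_model")
outputs = llm.chat(messages=[{"role": "user", "content": "Hello"}])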
@@ -80,7 +80,7 @@ parser.add_argument("--max-concurrency", default=512, type=int, help="max concur
 parser = EngineArgs.add_cli_args(parser)
 args = parser.parse_args()
 args.model = retrive_model_from_server(args.model, args.revision)
-chat_template = load_chat_template(args.chat_template)
+chat_template = load_chat_template(args.chat_template, args.model)
 if args.tool_parser_plugin:
     ToolParserManager.import_tool_parser(args.tool_parser_plugin)
 llm_engine = None