From 850465e8ed8e4318aea065191a63a03458f6a352 Mon Sep 17 00:00:00 2001
From: memoryCoderC <1137889088@qq.com>
Date: Thu, 11 Sep 2025 19:53:14 +0800
Subject: [PATCH] [Feature] add CLI commands chat, complete (#4037)

---
 fastdeploy/entrypoints/cli/__init__.py |   0
 fastdeploy/entrypoints/cli/main.py     |  56 ++++++
 fastdeploy/entrypoints/cli/openai.py   | 226 +++++++++++++++++++++++++
 fastdeploy/entrypoints/cli/types.py    |  42 +++++
 setup.py                               |   3 +
 tests/entrypoints/cli/test_main.py     |  26 +++
 tests/entrypoints/cli/test_openai.py   | 201 ++++++++++++++++++++++
 tests/entrypoints/cli/test_types.py    |  42 +++++
 8 files changed, 596 insertions(+)
 create mode 100644 fastdeploy/entrypoints/cli/__init__.py
 create mode 100644 fastdeploy/entrypoints/cli/main.py
 create mode 100644 fastdeploy/entrypoints/cli/openai.py
 create mode 100644 fastdeploy/entrypoints/cli/types.py
 create mode 100644 tests/entrypoints/cli/test_main.py
 create mode 100644 tests/entrypoints/cli/test_openai.py
 create mode 100644 tests/entrypoints/cli/test_types.py

diff --git a/fastdeploy/entrypoints/cli/__init__.py b/fastdeploy/entrypoints/cli/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/fastdeploy/entrypoints/cli/main.py b/fastdeploy/entrypoints/cli/main.py
new file mode 100644
index 000000000..a4ba74afe
--- /dev/null
+++ b/fastdeploy/entrypoints/cli/main.py
@@ -0,0 +1,56 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""

+# This file is modified from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/cli/main.py
+from __future__ import annotations
+
+import importlib.metadata
+
+
+def main():
+    import fastdeploy.entrypoints.cli.openai
+    from fastdeploy.utils import FlexibleArgumentParser
+
+    CMD_MODULES = [
+        fastdeploy.entrypoints.cli.openai,
+    ]
+
+    parser = FlexibleArgumentParser(description="FastDeploy CLI")
+    parser.add_argument(
+        "-v",
+        "--version",
+        action="version",
+        version=importlib.metadata.version("fastdeploy"),
+    )
+    subparsers = parser.add_subparsers(required=False, dest="subparser")
+    cmds = {}
+    for cmd_module in CMD_MODULES:
+        new_cmds = cmd_module.cmd_init()
+        for cmd in new_cmds:
+            cmd.subparser_init(subparsers).set_defaults(dispatch_function=cmd.cmd)
+            cmds[cmd.name] = cmd
+    args = parser.parse_args()
+    if args.subparser in cmds:
+        cmds[args.subparser].validate(args)
+
+    if hasattr(args, "dispatch_function"):
+        args.dispatch_function(args)
+    else:
+        parser.print_help()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/fastdeploy/entrypoints/cli/openai.py b/fastdeploy/entrypoints/cli/openai.py
new file mode 100644
index 000000000..0ab4c9ae0
--- /dev/null
+++ b/fastdeploy/entrypoints/cli/openai.py
@@ -0,0 +1,226 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+# This file is modified from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/cli/openai.py
+
+from __future__ import annotations
+
+import argparse
+import os
+import signal
+import sys
+from typing import TYPE_CHECKING
+
+from openai import OpenAI
+from openai.types.chat import ChatCompletionMessageParam
+
+from fastdeploy.entrypoints.cli.types import CLISubcommand
+
+if TYPE_CHECKING:
+    from fastdeploy.utils import FlexibleArgumentParser
+
+
+def _register_signal_handlers():
+
+    def signal_handler(sig, frame):
+        sys.exit(0)
+
+    signal.signal(signal.SIGINT, signal_handler)
+    signal.signal(signal.SIGTSTP, signal_handler)
+
+
+def _interactive_cli(args: argparse.Namespace) -> tuple[str, OpenAI]:
+    _register_signal_handlers()
+
+    base_url = args.url
+    api_key = args.api_key or os.environ.get("OPENAI_API_KEY", "EMPTY")
+    openai_client = OpenAI(api_key=api_key, base_url=base_url)
+
+    if args.model_name:
+        model_name = args.model_name
+    else:
+        available_models = openai_client.models.list()
+        model_name = available_models.data[0].id
+
+    print(f"Using model: {model_name}")
+
+    return model_name, openai_client
+
+
+def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None:
+    conversation: list[ChatCompletionMessageParam] = []
+    if system_prompt is not None:
+        conversation.append({"role": "system", "content": system_prompt})
+
+    print("Please enter a message for the chat model:")
+    while True:
+        try:
+            input_message = input("> ")
+        except EOFError:
+            break
+        conversation.append({"role": "user", "content": input_message})
+
+        chat_completion = client.chat.completions.create(model=model_name, messages=conversation)
+
+        response_message = chat_completion.choices[0].message
+        output = response_message.content
+
+        conversation.append(response_message)  # type: ignore
+        print(output)
+
+
+def _add_query_options(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
+    parser.add_argument(
+        "--url",
+        type=str,
+        default="http://localhost:9904/v1",
+        help="URL of the running OpenAI-compatible RESTful API server",
+    )
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        default=None,
+        help=("The model name used in prompt completion; defaults to " "the first model returned by the list models API call."),
+    )
+    parser.add_argument(
+        "--api-key",
+        type=str,
+        default=None,
+        help=(
+            "API key for OpenAI services. If provided, this API key "
+            "overrides the API key obtained through environment variables."
+        ),
+    )
+    return parser
+
+
+class ChatCommand(CLISubcommand):
+    """The `chat` subcommand for the fastdeploy CLI."""
+
+    name = "chat"
+
+    @staticmethod
+    def cmd(args: argparse.Namespace) -> None:
+        model_name, client = _interactive_cli(args)
+        system_prompt = args.system_prompt
+        conversation: list[ChatCompletionMessageParam] = []
+
+        if system_prompt is not None:
+            conversation.append({"role": "system", "content": system_prompt})
+
+        if args.quick:
+            conversation.append({"role": "user", "content": args.quick})
+
+            chat_completion = client.chat.completions.create(model=model_name, messages=conversation)
+            print(chat_completion.choices[0].message.content)
+            return
+
+        print("Please enter a message for the chat model:")
+        while True:
+            try:
+                input_message = input("> ")
+            except EOFError:
+                break
+            conversation.append({"role": "user", "content": input_message})
+
+            chat_completion = client.chat.completions.create(model=model_name, messages=conversation)
+
+            response_message = chat_completion.choices[0].message
+            output = response_message.content
+
+            conversation.append(response_message)  # type: ignore
+            print(output)
+
+    @staticmethod
+    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
+        """Add CLI arguments for the chat command."""
+        _add_query_options(parser)
+        parser.add_argument(
+            "--system-prompt",
+            type=str,
+            default=None,
+            help=(
+                "The system prompt to be added to the chat template, " "used for models that support system prompts."
+            ),
+        )
+        parser.add_argument(
+            "-q",
+            "--quick",
+            type=str,
+            metavar="MESSAGE",
+            help=("Send a single prompt as MESSAGE " "and print the response, then exit."),
+        )
+        return parser
+
+    def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
+        parser = subparsers.add_parser(
+            "chat",
+            help="Generate chat completions via the running API server.",
+            description="Generate chat completions via the running API server.",
+            usage="fastdeploy chat [options]",
+        )
+        return ChatCommand.add_cli_args(parser)
+
+
+class CompleteCommand(CLISubcommand):
+    """The `complete` subcommand for the fastdeploy CLI."""
+
+    name = "complete"
+
+    @staticmethod
+    def cmd(args: argparse.Namespace) -> None:
+        model_name, client = _interactive_cli(args)
+
+        if args.quick:
+            completion = client.completions.create(model=model_name, prompt=args.quick)
+            print(completion.choices[0].text)
+            return
+
+        print("Please enter prompt to complete:")
+        while True:
+            try:
+                input_prompt = input("> ")
+            except EOFError:
+                break
+            completion = client.completions.create(model=model_name, prompt=input_prompt)
+            output = completion.choices[0].text
+            print(output)
+
+    @staticmethod
+    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
+        """Add CLI arguments for the complete command."""
+        _add_query_options(parser)
+        parser.add_argument(
+            "-q",
+            "--quick",
+            type=str,
+            metavar="PROMPT",
+            help="Send a single prompt and print the completion output, then exit.",
+        )
+        return parser
+
+    def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
+        parser = subparsers.add_parser(
+            "complete",
+            help=("Generate text completions based on the given prompt " "via the running API server."),
+            description=("Generate text completions based on the given prompt " "via the running API server."),
+            usage="fastdeploy complete [options]",
+        )
+        return CompleteCommand.add_cli_args(parser)
+
+
+def cmd_init() -> list[CLISubcommand]:
+    return [ChatCommand(), CompleteCommand()]
diff --git a/fastdeploy/entrypoints/cli/types.py b/fastdeploy/entrypoints/cli/types.py
new file mode 100644
index 000000000..b86e2d6b6
--- /dev/null
+++ b/fastdeploy/entrypoints/cli/types.py
@@ -0,0 +1,42 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+# This file is modified from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/cli/types.py
+
+from __future__ import annotations
+
+import argparse
+import typing
+
+if typing.TYPE_CHECKING:
+    from fastdeploy.utils import FlexibleArgumentParser
+
+
+class CLISubcommand:
+    """Base class for CLI argument handlers."""
+
+    name: str
+
+    @staticmethod
+    def cmd(args: argparse.Namespace) -> None:
+        raise NotImplementedError("Subclasses should implement this method")
+
+    def validate(self, args: argparse.Namespace) -> None:
+        # No validation by default
+        pass
+
+    def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
+        raise NotImplementedError("Subclasses should implement this method")
diff --git a/setup.py b/setup.py
index 1a4f4b2dd..1e9878936 100644
--- a/setup.py
+++ b/setup.py
@@ -238,4 +238,7 @@ setup(
     license="Apache 2.0",
     python_requires=">=3.7",
     extras_require={"test": ["pytest>=6.0"]},
+    entry_points={
+        "console_scripts": ["fastdeploy=fastdeploy.entrypoints.cli.main:main"],
+    },
 )
diff --git a/tests/entrypoints/cli/test_main.py b/tests/entrypoints/cli/test_main.py
new file mode 100644
index 000000000..dada7f624
--- /dev/null
+++ b/tests/entrypoints/cli/test_main.py
@@ -0,0 +1,26 @@
+import unittest
+from unittest.mock import MagicMock, patch
+
+from fastdeploy.entrypoints.cli.main import main as cli_main
+
+
+class TestCliMain(unittest.TestCase):
+    @patch("fastdeploy.utils.FlexibleArgumentParser")
+    @patch("fastdeploy.entrypoints.cli.main.importlib.metadata")
+    def test_main_basic(self, mock_metadata, mock_parser):
+        # Setup mocks
+        mock_metadata.version.return_value = "1.0.0"
+        mock_args = MagicMock()
+        mock_args.subparser = None
+        mock_parser.return_value.parse_args.return_value = mock_args
+
+        # Test basic call
+        cli_main()
+
+        # Verify version check
+        mock_metadata.version.assert_called_once_with("fastdeploy")
+        mock_args.dispatch_function.assert_called_once()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/entrypoints/cli/test_openai.py b/tests/entrypoints/cli/test_openai.py
new file mode 100644
index 000000000..81cf79b2c
--- /dev/null
+++ b/tests/entrypoints/cli/test_openai.py
@@ -0,0 +1,201 @@
+import argparse
+import signal
+import unittest
+from unittest.mock import MagicMock, call, patch
+
+from fastdeploy.entrypoints.cli.openai import (
+    ChatCommand,
+    CompleteCommand,
+    _add_query_options,
+    _interactive_cli,
+    _register_signal_handlers,
+    chat,
+    cmd_init,
+)
+
+
+class TestOpenAICli(unittest.TestCase):
+
+    @patch("fastdeploy.entrypoints.cli.openai.signal.signal")
+    def test_register_signal_handlers(self, mock_signal):
+        """Test signal handler registration."""
+        _register_signal_handlers()
+
+        # Verify the signal handlers are registered correctly
+        mock_signal.assert_has_calls([call(signal.SIGINT, unittest.mock.ANY), call(signal.SIGTSTP, unittest.mock.ANY)])
+
+    @patch("fastdeploy.entrypoints.cli.openai.OpenAI")
+    @patch("fastdeploy.entrypoints.cli.openai.os.environ.get")
+    @patch("fastdeploy.entrypoints.cli.openai._register_signal_handlers")
+    def test_interactive_cli_with_model_name(self, mock_register, mock_environ, mock_openai):
+        """Test interactive CLI initialization with a model name specified."""
+        # Set up mocks
+        mock_environ.return_value = "test_api_key"
+        mock_client = MagicMock()
+        mock_openai.return_value = mock_client
+
+        # Test arguments
+        args = argparse.Namespace()
+        args.url = "http://localhost:9904/v1"
+        args.api_key = None
+        args.model_name = "test-model"
+
+        # Run the test
+        model_name, client = _interactive_cli(args)
+
+        # Verify the results
+        self.assertEqual(model_name, "test-model")
+        self.assertEqual(client, mock_client)
+        mock_openai.assert_called_once_with(api_key="test_api_key", base_url="http://localhost:9904/v1")
+        mock_register.assert_called_once()
+
+    @patch("fastdeploy.entrypoints.cli.openai.OpenAI")
+    @patch("fastdeploy.entrypoints.cli.openai.os.environ.get")
+    @patch("fastdeploy.entrypoints.cli.openai._register_signal_handlers")
+    def test_interactive_cli_without_model_name(self, mock_register, mock_environ, mock_openai):
+        """Test interactive CLI initialization without a model name."""
+        # Set up mocks
+        mock_environ.return_value = "test_api_key"
+        mock_client = MagicMock()
+        mock_models = MagicMock()
+        mock_models.data = [MagicMock(id="first-model")]
+        mock_client.models.list.return_value = mock_models
+        mock_openai.return_value = mock_client
+
+        # Test arguments
+        args = argparse.Namespace()
+        args.url = "http://localhost:9904/v1"
+        args.api_key = None
+        args.model_name = None
+
+        # Run the test
+        model_name, client = _interactive_cli(args)
+
+        # Verify the results
+        self.assertEqual(model_name, "first-model")
+        self.assertEqual(client, mock_client)
+        mock_client.models.list.assert_called_once()
+
+    @patch("fastdeploy.entrypoints.cli.openai.input")
+    @patch("fastdeploy.entrypoints.cli.openai.OpenAI")
+    def test_chat_function(self, mock_openai, mock_input):
+        """Test the chat function."""
+        # Set up mocks
+        mock_client = MagicMock()
+        mock_completion = MagicMock()
+        mock_completion.choices = [MagicMock()]
+        mock_completion.choices[0].message = MagicMock(content="Test response")
+        mock_client.chat.completions.create.return_value = mock_completion
+
+        # Simulate user input followed by EOF
+        mock_input.side_effect = ["Hello", EOFError]
+
+        # Run the test
+        chat("System prompt", "test-model", mock_client)
+
+        # Verify the API call
+        mock_client.chat.completions.create.assert_called_once()
+
+    def test_add_query_options(self):
+        """Test that the query options are added."""
+        mock_parser = MagicMock()
+
+        result = _add_query_options(mock_parser)
+
+        # Verify the parser methods are called
+        self.assertEqual(result, mock_parser)
+        self.assertEqual(mock_parser.add_argument.call_count, 3)
+
+    def test_cmd_init(self):
+        """Test command initialization."""
+        commands = cmd_init()
+
+        # Verify the returned command list
+        self.assertEqual(len(commands), 2)
+        self.assertEqual(commands[0].name, "chat")
+        self.assertEqual(commands[1].name, "complete")
+
+
+class TestChatCommand(unittest.TestCase):
+
+    @patch("fastdeploy.entrypoints.cli.openai._interactive_cli")
+    @patch("fastdeploy.entrypoints.cli.openai.OpenAI")
+    def test_chat_command_quick_mode(self, mock_openai, mock_interactive):
+        """Test ChatCommand quick mode."""
+        # Set up mocks
+        mock_interactive.return_value = ("test-model", MagicMock())
+
+        args = argparse.Namespace()
+        args.quick = "Quick message"
+        args.system_prompt = None
+
+        # Run the test
+        ChatCommand.cmd(args)
+
+        # Verify _interactive_cli is called
+        mock_interactive.assert_called_once_with(args)
+
@patch("fastdeploy.entrypoints.cli.openai.input") + @patch("fastdeploy.entrypoints.cli.openai.OpenAI") + def test_chat_empty_input(self, mock_openai, mock_input): + """Test chat with empty input.""" + mock_client = MagicMock() + mock_openai.return_value = mock_client + + # Mock empty input then EOF + mock_input.side_effect = ["", EOFError()] + + args = argparse.Namespace() + args.quick = None + args.url = "http://test.com" + args.api_key = None + + args.model_name = None + args.system_prompt = "System prompt" + + ChatCommand().cmd(args) + + +class TestCompleteCommand(unittest.TestCase): + + @patch("fastdeploy.entrypoints.cli.openai._interactive_cli") + @patch("fastdeploy.entrypoints.cli.openai.OpenAI") + def test_complete_command_quick_mode(self, mock_openai, mock_interactive): + """测试CompleteCommand快速模式""" + # 设置mock + mock_client = MagicMock() + mock_completion = MagicMock() + mock_completion.choices = [MagicMock(text="Completion text")] + mock_client.completions.create.return_value = mock_completion + mock_interactive.return_value = ("test-model", mock_client) + + args = argparse.Namespace() + args.quick = "Quick prompt" + + # 执行测试 + CompleteCommand.cmd(args) + + # 验证API调用 + mock_client.completions.create.assert_called_once_with(model="test-model", prompt="Quick prompt") + + @patch("fastdeploy.entrypoints.cli.openai.input") + @patch("fastdeploy.entrypoints.cli.openai.OpenAI") + def test_completion_empty_input(self, mock_openai, mock_input): + """Test completion with empty input.""" + mock_client = MagicMock() + mock_openai.return_value = mock_client + + # Mock empty input then EOF + mock_input.side_effect = ["", EOFError()] + + args = argparse.Namespace() + args.quick = None + args.url = "http://test.com" + args.api_key = None + args.model_name = None + + CompleteCommand.cmd(args) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/entrypoints/cli/test_types.py b/tests/entrypoints/cli/test_types.py new file mode 100644 index 000000000..22b099855 --- /dev/null +++ b/tests/entrypoints/cli/test_types.py @@ -0,0 +1,42 @@ +import unittest +from unittest.mock import MagicMock + +from fastdeploy.entrypoints.cli.types import CLISubcommand + + +class TestCLISubcommand(unittest.TestCase): + """Test cases for CLISubcommand class.""" + + def test_abstract_methods(self): + """Test that abstract methods raise NotImplementedError.""" + with self.assertRaises(NotImplementedError): + CLISubcommand.cmd(None) + + with self.assertRaises(NotImplementedError): + CLISubcommand().subparser_init(None) + + def test_validate_default_implementation(self): + """Test the default validate implementation does nothing.""" + # Should not raise any exception + CLISubcommand().validate(None) + + def test_name_attribute(self): + """Test that name attribute is required.""" + + class TestSubcommand(CLISubcommand): + name = "test" + + @staticmethod + def cmd(args): + pass + + def subparser_init(self, subparsers): + return MagicMock() + + # Should not raise AttributeError + test_cmd = TestSubcommand() + self.assertEqual(test_cmd.name, "test") + + +if __name__ == "__main__": + unittest.main()