mirror of
				https://github.com/PaddlePaddle/FastDeploy.git
				synced 2025-10-31 20:02:53 +08:00 
			
		
		
		
	[Feature] add cli command chat,complete (#4037)
This commit is contained in:
		
							
								
								
									
										0
									
								
								fastdeploy/entrypoints/cli/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								fastdeploy/entrypoints/cli/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										56
									
								
								fastdeploy/entrypoints/cli/main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										56
									
								
								fastdeploy/entrypoints/cli/main.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,56 @@ | ||||
| """ | ||||
| # Copyright (c) 2025  PaddlePaddle Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| """ | ||||
|  | ||||
| # This file is modified from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/cli/main.py | ||||
| from __future__ import annotations | ||||
|  | ||||
| import importlib.metadata | ||||
|  | ||||
|  | ||||
def main():
    """Entry point for the ``fastdeploy`` console script.

    Builds the top-level argument parser, registers every subcommand
    exposed by the CLI modules, then dispatches to the selected one
    (or prints help when no subcommand was chosen).
    """
    import fastdeploy.entrypoints.cli.openai
    from fastdeploy.utils import FlexibleArgumentParser

    # Modules that contribute subcommands via a cmd_init() hook.
    cmd_modules = [
        fastdeploy.entrypoints.cli.openai,
    ]

    parser = FlexibleArgumentParser(description="FastDeploy CLI")
    parser.add_argument(
        "-v",
        "--version",
        action="version",
        version=importlib.metadata.version("fastdeploy"),
    )
    subparsers = parser.add_subparsers(required=False, dest="subparser")

    registered = {}
    for module in cmd_modules:
        for command in module.cmd_init():
            command.subparser_init(subparsers).set_defaults(dispatch_function=command.cmd)
            registered[command.name] = command

    args = parser.parse_args()
    if args.subparser in registered:
        registered[args.subparser].validate(args)

    dispatch = getattr(args, "dispatch_function", None)
    if dispatch is None:
        parser.print_help()
    else:
        dispatch(args)
							
								
								
									
										226
									
								
								fastdeploy/entrypoints/cli/openai.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										226
									
								
								fastdeploy/entrypoints/cli/openai.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,226 @@ | ||||
| """ | ||||
| # Copyright (c) 2025  PaddlePaddle Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| """ | ||||
|  | ||||
| # This file is modified from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/cli/openai.py | ||||
|  | ||||
| from __future__ import annotations | ||||
|  | ||||
| import argparse | ||||
| import os | ||||
| import signal | ||||
| import sys | ||||
| from typing import TYPE_CHECKING | ||||
|  | ||||
| from openai import OpenAI | ||||
| from openai.types.chat import ChatCompletionMessageParam | ||||
|  | ||||
| from fastdeploy.entrypoints.cli.types import CLISubcommand | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from fastdeploy.utils import FlexibleArgumentParser | ||||
|  | ||||
|  | ||||
def _register_signal_handlers():
    """Install handlers that exit the interactive CLI cleanly.

    Exits with status 0 on Ctrl-C (SIGINT) and, where available, on
    Ctrl-Z (SIGTSTP) so the REPL loops below terminate gracefully.
    """

    def signal_handler(sig, frame):
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    # SIGTSTP is POSIX-only; referencing it unconditionally raises
    # AttributeError on Windows, so register it only when present.
    if hasattr(signal, "SIGTSTP"):
        signal.signal(signal.SIGTSTP, signal_handler)
|  | ||||
|  | ||||
def _interactive_cli(args: argparse.Namespace) -> tuple[str, OpenAI]:
    """Build an OpenAI client from CLI args and resolve the model name.

    Falls back to the ``OPENAI_API_KEY`` environment variable (or
    ``"EMPTY"``) when ``--api-key`` was not given, and to the first
    model listed by the server when ``--model-name`` was not given.
    """
    _register_signal_handlers()

    key = args.api_key if args.api_key else os.environ.get("OPENAI_API_KEY", "EMPTY")
    client = OpenAI(api_key=key, base_url=args.url)

    model_name = args.model_name
    if not model_name:
        # No model requested explicitly: ask the server what it serves.
        model_name = client.models.list().data[0].id

    print(f"Using model: {model_name}")
    return model_name, client
|  | ||||
|  | ||||
def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None:
    """Run an interactive chat session against the running API server.

    Reads user messages from stdin until EOF (Ctrl-D), keeping the full
    conversation history so the model sees all prior turns.
    """
    conversation: list[ChatCompletionMessageParam] = []
    if system_prompt is not None:
        conversation.append({"role": "system", "content": system_prompt})

    print("Please enter a message for the chat model:")
    while True:
        try:
            user_message = input("> ")
        except EOFError:
            break
        conversation.append({"role": "user", "content": user_message})

        reply = client.chat.completions.create(model=model_name, messages=conversation).choices[0].message

        # Keep the assistant turn in the history for the next request.
        conversation.append(reply)  # type: ignore
        print(reply.content)
|  | ||||
|  | ||||
def _add_query_options(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    """Attach the server-query options shared by ``chat`` and ``complete``.

    Adds ``--url``, ``--model-name`` and ``--api-key`` and returns the
    same parser for chaining.
    """
    parser.add_argument(
        "--url",
        type=str,
        default="http://localhost:9904/v1",
        help="url of the running OpenAI-Compatible RESTful API server",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default=None,
        help="The model name used in prompt completion, default to the first model in list models API call.",
    )
    parser.add_argument(
        "--api-key",
        type=str,
        default=None,
        help=(
            "API key for OpenAI services. If provided, this api key "
            "will overwrite the api key obtained through environment variables."
        ),
    )
    return parser
|  | ||||
|  | ||||
class ChatCommand(CLISubcommand):
    """The `chat` subcommand for the fastdeploy CLI."""

    name = "chat"

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Run the chat subcommand.

        With ``--quick`` a single user message is sent and the reply
        printed; otherwise the shared interactive loop in :func:`chat`
        runs until EOF.
        """
        model_name, client = _interactive_cli(args)
        system_prompt = args.system_prompt

        if args.quick:
            conversation: list[ChatCompletionMessageParam] = []
            if system_prompt is not None:
                conversation.append({"role": "system", "content": system_prompt})
            conversation.append({"role": "user", "content": args.quick})

            chat_completion = client.chat.completions.create(model=model_name, messages=conversation)
            print(chat_completion.choices[0].message.content)
            return

        # Delegate to the module-level `chat` helper instead of
        # duplicating the interactive loop here.
        chat(system_prompt, model_name, client)

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Add CLI arguments for the chat command."""
        _add_query_options(parser)
        parser.add_argument(
            "--system-prompt",
            type=str,
            default=None,
            help=(
                "The system prompt to be added to the chat template, " "used for models that support system prompts."
            ),
        )
        parser.add_argument(
            "-q",
            "--quick",
            type=str,
            metavar="MESSAGE",
            help=("Send a single prompt as MESSAGE " "and print the response, then exit."),
        )
        return parser

    def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        """Register the ``chat`` parser on *subparsers* and attach its options."""
        parser = subparsers.add_parser(
            "chat",
            help="Generate chat completions via the running API server.",
            description="Generate chat completions via the running API server.",
            usage="fastdeploy chat [options]",
        )
        return ChatCommand.add_cli_args(parser)
|  | ||||
|  | ||||
class CompleteCommand(CLISubcommand):
    """The `complete` subcommand for the fastdeploy CLI."""

    name = "complete"

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Run the complete subcommand.

        With ``--quick`` a single prompt is completed and printed;
        otherwise prompts are read from stdin until EOF (Ctrl-D).
        """
        model_name, client = _interactive_cli(args)

        if args.quick:
            completion = client.completions.create(model=model_name, prompt=args.quick)
            print(completion.choices[0].text)
            return

        print("Please enter prompt to complete:")
        while True:
            try:
                input_prompt = input("> ")
            except EOFError:
                break
            completion = client.completions.create(model=model_name, prompt=input_prompt)
            output = completion.choices[0].text
            print(output)

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Add CLI arguments for the complete command."""
        _add_query_options(parser)
        parser.add_argument(
            "-q",
            "--quick",
            type=str,
            metavar="PROMPT",
            help="Send a single prompt and print the completion output, then exit.",
        )
        return parser

    def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        """Register the ``complete`` parser on *subparsers* and attach its options."""
        parser = subparsers.add_parser(
            "complete",
            help=("Generate text completions based on the given prompt " "via the running API server."),
            description=("Generate text completions based on the given prompt " "via the running API server."),
            usage="fastdeploy complete [options]",
        )
        return CompleteCommand.add_cli_args(parser)
|  | ||||
|  | ||||
def cmd_init() -> list[CLISubcommand]:
    """Return one instance of every subcommand this module provides."""
    commands: list[CLISubcommand] = [ChatCommand(), CompleteCommand()]
    return commands
							
								
								
									
										42
									
								
								fastdeploy/entrypoints/cli/types.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								fastdeploy/entrypoints/cli/types.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,42 @@ | ||||
| """ | ||||
| # Copyright (c) 2025  PaddlePaddle Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| """ | ||||
|  | ||||
| # This file is modified from https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/cli/types.py | ||||
|  | ||||
| from __future__ import annotations | ||||
|  | ||||
| import argparse | ||||
| import typing | ||||
|  | ||||
| if typing.TYPE_CHECKING: | ||||
|     from fastdeploy.utils import FlexibleArgumentParser | ||||
|  | ||||
|  | ||||
class CLISubcommand:
    """Common interface for fastdeploy CLI subcommands.

    Concrete subcommands set ``name``, implement :meth:`cmd` and
    :meth:`subparser_init`, and may override :meth:`validate`.
    """

    # Command-line name of the subcommand (e.g. "chat").
    name: str

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        """Execute the subcommand with the parsed *args*."""
        raise NotImplementedError("Subclasses should implement this method")

    def validate(self, args: argparse.Namespace) -> None:
        """Validate *args* before dispatch; the base class accepts anything."""

    def subparser_init(self, subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
        """Create and return this subcommand's argument parser."""
        raise NotImplementedError("Subclasses should implement this method")
							
								
								
									
										3
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								setup.py
									
									
									
									
									
								
							| @@ -238,4 +238,7 @@ setup( | ||||
|     license="Apache 2.0", | ||||
|     python_requires=">=3.7", | ||||
|     extras_require={"test": ["pytest>=6.0"]}, | ||||
|     entry_points={ | ||||
|         "console_scripts": ["fastdeploy=fastdeploy.entrypoints.cli.main:main"], | ||||
|     }, | ||||
| ) | ||||
|   | ||||
							
								
								
									
										26
									
								
								tests/entrypoints/cli/test_main.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								tests/entrypoints/cli/test_main.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | ||||
| import unittest | ||||
| from unittest.mock import MagicMock, patch | ||||
|  | ||||
| from fastdeploy.entrypoints.cli.main import main as cli_main | ||||
|  | ||||
|  | ||||
class TestCliMain(unittest.TestCase):
    """Tests for the top-level CLI entry point."""

    @patch("fastdeploy.utils.FlexibleArgumentParser")
    @patch("fastdeploy.entrypoints.cli.main.importlib.metadata")
    def test_main_basic(self, mock_metadata, mock_parser):
        """main() queries the package version and dispatches the parsed args."""
        mock_metadata.version.return_value = "1.0.0"
        parsed = MagicMock()
        parsed.subparser = None
        mock_parser.return_value.parse_args.return_value = parsed

        cli_main()

        # The version string must come from package metadata, and the
        # parsed namespace's dispatch hook must have been invoked.
        mock_metadata.version.assert_called_once_with("fastdeploy")
        parsed.dispatch_function.assert_called_once()
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
							
								
								
									
										201
									
								
								tests/entrypoints/cli/test_openai.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										201
									
								
								tests/entrypoints/cli/test_openai.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,201 @@ | ||||
| import argparse | ||||
| import signal | ||||
| import unittest | ||||
| from unittest.mock import MagicMock, call, patch | ||||
|  | ||||
| from fastdeploy.entrypoints.cli.openai import ( | ||||
|     ChatCommand, | ||||
|     CompleteCommand, | ||||
|     _add_query_options, | ||||
|     _interactive_cli, | ||||
|     _register_signal_handlers, | ||||
|     chat, | ||||
|     cmd_init, | ||||
| ) | ||||
|  | ||||
|  | ||||
class TestOpenAICli(unittest.TestCase):
    """Tests for the module-level helpers in fastdeploy.entrypoints.cli.openai."""

    @patch("fastdeploy.entrypoints.cli.openai.signal.signal")
    def test_register_signal_handlers(self, mock_signal):
        """Signal handlers are registered."""
        _register_signal_handlers()

        # Verify that handlers were registered for both signals.
        mock_signal.assert_has_calls([call(signal.SIGINT, unittest.mock.ANY), call(signal.SIGTSTP, unittest.mock.ANY)])

    @patch("fastdeploy.entrypoints.cli.openai.OpenAI")
    @patch("fastdeploy.entrypoints.cli.openai.os.environ.get")
    @patch("fastdeploy.entrypoints.cli.openai._register_signal_handlers")
    def test_interactive_cli_with_model_name(self, mock_register, mock_environ, mock_openai):
        """Interactive CLI initialization when a model name is given."""
        # Set up mocks.
        mock_environ.return_value = "test_api_key"
        mock_client = MagicMock()
        mock_openai.return_value = mock_client

        # Test arguments.
        args = argparse.Namespace()
        args.url = "http://localhost:9904/v1"
        args.api_key = None
        args.model_name = "test-model"

        # Run the function under test.
        model_name, client = _interactive_cli(args)

        # Verify the results.
        self.assertEqual(model_name, "test-model")
        self.assertEqual(client, mock_client)
        mock_openai.assert_called_once_with(api_key="test_api_key", base_url="http://localhost:9904/v1")
        mock_register.assert_called_once()

    @patch("fastdeploy.entrypoints.cli.openai.OpenAI")
    @patch("fastdeploy.entrypoints.cli.openai.os.environ.get")
    @patch("fastdeploy.entrypoints.cli.openai._register_signal_handlers")
    def test_interactive_cli_without_model_name(self, mock_register, mock_environ, mock_openai):
        """Interactive CLI initialization when no model name is given."""
        # Set up mocks.
        mock_environ.return_value = "test_api_key"
        mock_client = MagicMock()
        mock_models = MagicMock()
        mock_models.data = [MagicMock(id="first-model")]
        mock_client.models.list.return_value = mock_models
        mock_openai.return_value = mock_client

        # Test arguments.
        args = argparse.Namespace()
        args.url = "http://localhost:9904/v1"
        args.api_key = None
        args.model_name = None

        # Run the function under test.
        model_name, client = _interactive_cli(args)

        # Verify the results: falls back to the server's first listed model.
        self.assertEqual(model_name, "first-model")
        self.assertEqual(client, mock_client)
        mock_client.models.list.assert_called_once()

    @patch("fastdeploy.entrypoints.cli.openai.input")
    @patch("fastdeploy.entrypoints.cli.openai.OpenAI")
    def test_chat_function(self, mock_openai, mock_input):
        """The chat() helper sends one request per user message."""
        # Set up mocks.
        mock_client = MagicMock()
        mock_completion = MagicMock()
        mock_completion.choices = [MagicMock()]
        mock_completion.choices[0].message = MagicMock(content="Test response")
        mock_client.chat.completions.create.return_value = mock_completion

        # Simulate one user input followed by EOF.
        mock_input.side_effect = ["Hello", EOFError]

        # Run the function under test.
        chat("System prompt", "test-model", mock_client)

        # Verify the API call.
        mock_client.chat.completions.create.assert_called_once()

    def test_add_query_options(self):
        """_add_query_options registers the shared query arguments."""
        mock_parser = MagicMock()

        result = _add_query_options(mock_parser)

        # Verify the parser methods were called (url, model-name, api-key).
        self.assertEqual(result, mock_parser)
        self.assertEqual(mock_parser.add_argument.call_count, 3)

    def test_cmd_init(self):
        """cmd_init returns the chat and complete subcommands."""
        commands = cmd_init()

        # Verify the returned command list.
        self.assertEqual(len(commands), 2)
        self.assertEqual(commands[0].name, "chat")
        self.assertEqual(commands[1].name, "complete")
|  | ||||
|  | ||||
class TestChatCommand(unittest.TestCase):
    """Tests for ChatCommand."""

    @patch("fastdeploy.entrypoints.cli.openai._interactive_cli")
    @patch("fastdeploy.entrypoints.cli.openai.OpenAI")
    def test_chat_command_quick_mode(self, mock_openai, mock_interactive):
        """ChatCommand quick mode."""
        # Set up mocks.
        mock_interactive.return_value = ("test-model", MagicMock())

        args = argparse.Namespace()
        args.quick = "Quick message"
        args.system_prompt = None

        # Run the command.
        ChatCommand.cmd(args)

        # Verify that _interactive_cli was called.
        mock_interactive.assert_called_once_with(args)

    @patch("fastdeploy.entrypoints.cli.openai.input")
    @patch("fastdeploy.entrypoints.cli.openai.OpenAI")
    def test_chat_empty_input(self, mock_openai, mock_input):
        """Test chat with empty input."""
        mock_client = MagicMock()
        mock_openai.return_value = mock_client

        # Mock empty input then EOF
        mock_input.side_effect = ["", EOFError()]

        args = argparse.Namespace()
        args.quick = None
        args.url = "http://test.com"
        args.api_key = None

        args.model_name = None
        args.system_prompt = "System prompt"

        ChatCommand().cmd(args)
|  | ||||
|  | ||||
class TestCompleteCommand(unittest.TestCase):
    """Tests for CompleteCommand."""

    @patch("fastdeploy.entrypoints.cli.openai._interactive_cli")
    @patch("fastdeploy.entrypoints.cli.openai.OpenAI")
    def test_complete_command_quick_mode(self, mock_openai, mock_interactive):
        """CompleteCommand quick mode."""
        # Set up mocks.
        mock_client = MagicMock()
        mock_completion = MagicMock()
        mock_completion.choices = [MagicMock(text="Completion text")]
        mock_client.completions.create.return_value = mock_completion
        mock_interactive.return_value = ("test-model", mock_client)

        args = argparse.Namespace()
        args.quick = "Quick prompt"

        # Run the command.
        CompleteCommand.cmd(args)

        # Verify the API call.
        mock_client.completions.create.assert_called_once_with(model="test-model", prompt="Quick prompt")

    @patch("fastdeploy.entrypoints.cli.openai.input")
    @patch("fastdeploy.entrypoints.cli.openai.OpenAI")
    def test_completion_empty_input(self, mock_openai, mock_input):
        """Test completion with empty input."""
        mock_client = MagicMock()
        mock_openai.return_value = mock_client

        # Mock empty input then EOF
        mock_input.side_effect = ["", EOFError()]

        args = argparse.Namespace()
        args.quick = None
        args.url = "http://test.com"
        args.api_key = None
        args.model_name = None

        CompleteCommand.cmd(args)
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
							
								
								
									
										42
									
								
								tests/entrypoints/cli/test_types.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								tests/entrypoints/cli/test_types.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,42 @@ | ||||
| import unittest | ||||
| from unittest.mock import MagicMock | ||||
|  | ||||
| from fastdeploy.entrypoints.cli.types import CLISubcommand | ||||
|  | ||||
|  | ||||
class TestCLISubcommand(unittest.TestCase):
    """Test cases for CLISubcommand class."""

    def test_abstract_methods(self):
        """Unimplemented hooks must raise NotImplementedError."""
        with self.assertRaises(NotImplementedError):
            CLISubcommand.cmd(None)
        with self.assertRaises(NotImplementedError):
            CLISubcommand().subparser_init(None)

    def test_validate_default_implementation(self):
        """The default validate() accepts anything without raising."""
        self.assertIsNone(CLISubcommand().validate(None))

    def test_name_attribute(self):
        """A concrete subclass exposes its name attribute."""

        class DummySubcommand(CLISubcommand):
            name = "test"

            @staticmethod
            def cmd(args):
                pass

            def subparser_init(self, subparsers):
                return MagicMock()

        self.assertEqual(DummySubcommand().name, "test")
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     unittest.main() | ||||
		Reference in New Issue
	
	Block a user
	 memoryCoderC
					memoryCoderC