From f15edbb6efa9efa1185a931687aef51ac7041db2 Mon Sep 17 00:00:00 2001 From: kesmeey <107767849+kesmeey@users.noreply.github.com> Date: Tue, 23 Dec 2025 18:06:43 +0800 Subject: [PATCH] =?UTF-8?q?[CI]=E3=80=90Hackathon=209th=20Sprint=20No.40?= =?UTF-8?q?=E3=80=91=E5=8A=9F=E8=83=BD=E6=A8=A1=E5=9D=97=20fastdeploy/entr?= =?UTF-8?q?ypoints/openai/api=5Fserver.py=20=E5=8D=95=E6=B5=8B=E8=A1=A5?= =?UTF-8?q?=E5=85=85=20(#5567)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add tests for openai api_server coverage * update * Update tests for openai api_server * fix bugs * test: disable some api_server lifespan/controller tests for local env * Format test_api_server with black * update * update * test: narrow envs patch in api_server tests to avoid side effects * fix: separate MagicMock creation to avoid missing req argument * fix: patch TRACES_ENABLE env var in api_server tests * fix: use os.environ patch for TRACES_ENABLE * test: use fake fastdeploy.envs in api_server tests * test: pass fake Request into chat/completion routes * test: increase coverage for tracing and scheduler control * fix: set dynamic_load_weight in tracing headers test * ci: add retry and validation for FastDeploy.tar.gz download * ci: fix indentation in _base_test.yml * refactor: simplify test_api_server.py (807->480 lines, ~40% reduction) * fix: restore missing args attributes (revision, etc.) 
in _build_args * fix: patch sys.argv to prevent SystemExit: 2 in api_server tests * improve coverage * Remove docstring from test_api_server.py Removed unnecessary docstring from test_api_server.py --------- Co-authored-by: CSWYF3634076 --- Reviewer notes (commentary after the three-dash line, not part of the commit): the workflow hunk re-emits the checkout step's run script (commit message: 'ci: fix indentation in _base_test.yml') and replaces the single wget call with a loop of up to MAX_RETRIES=3 attempts that requires FastDeploy.tar.gz to exist and be non-empty, followed by a 'tar -tzf' integrity check before extraction. The new test module re-imports api_server through _reload_api_server with FlexibleArgumentParser.parse_args, retrive_model_from_server, load_chat_template and fastdeploy.envs patched and sys.argv temporarily replaced. Observations for follow-up: _patch_common_imports accepts 'args' but never reads it; test_wrap_streaming_generator, test_chat_and_completion_routes and test_chat_completion_tracing assign api_server.connection_semaphore directly rather than through patch(); the final assertion of test_list_models evaluates to True whenever resp2.body has no 'decode' attribute. .github/workflows/_base_test.yml | 69 +- tests/entrypoints/openai/test_api_server.py | 711 ++++++++++++++++++++ 2 files changed, 758 insertions(+), 22 deletions(-) create mode 100644 tests/entrypoints/openai/test_api_server.py diff --git a/.github/workflows/_base_test.yml b/.github/workflows/_base_test.yml index 851c0648b..4087a50ff 100644 --- a/.github/workflows/_base_test.yml +++ b/.github/workflows/_base_test.yml @@ -39,29 +39,54 @@ jobs: docker_image: ${{ inputs.DOCKER_IMAGE }} fd_archive_url: ${{ inputs.FASTDEPLOY_ARCHIVE_URL }} run: | - set -x - REPO="https://github.com/${{ github.repository }}.git" - FULL_REPO="${{ github.repository }}" - REPO_NAME="${FULL_REPO##*/}" - BASE_BRANCH="${{ github.base_ref }}" - docker pull ${docker_image} - # Clean the repository directory before starting - docker run --rm --net=host -v $(pwd):/workspace -w /workspace \ - -e "REPO_NAME=${REPO_NAME}" \ - ${docker_image} /bin/bash -c ' - if [ -d ${REPO_NAME} ]; then - echo "Directory ${REPO_NAME} exists, removing it..." - rm -rf ${REPO_NAME}* - fi - ' + set -x + REPO="https://github.com/${{ github.repository }}.git" + FULL_REPO="${{ github.repository }}" + REPO_NAME="${FULL_REPO##*/}" + BASE_BRANCH="${{ github.base_ref }}" + docker pull ${docker_image} + # Clean the repository directory before starting + docker run --rm --net=host -v $(pwd):/workspace -w /workspace \ + -e "REPO_NAME=${REPO_NAME}" \ + ${docker_image} /bin/bash -c ' + if [ -d ${REPO_NAME} ]; then + echo "Directory ${REPO_NAME} exists, removing it..." 
+ rm -rf ${REPO_NAME}* + fi + ' - wget -q --no-proxy ${fd_archive_url} - tar -xf FastDeploy.tar.gz - rm -rf FastDeploy.tar.gz - cd FastDeploy - git config --global user.name "FastDeployCI" - git config --global user.email "fastdeploy_ci@example.com" - git log -n 3 --oneline + # Download with retry and validation + MAX_RETRIES=3 + RETRY_COUNT=0 + while [ $RETRY_COUNT -lt $MAX_RETRIES ]; do + if wget -q --no-proxy ${fd_archive_url} && [ -f FastDeploy.tar.gz ] && [ -s FastDeploy.tar.gz ]; then + echo "Download successful, file size: $(stat -c%s FastDeploy.tar.gz) bytes" + break + else + RETRY_COUNT=$((RETRY_COUNT + 1)) + echo "Download failed or file is empty, retry $RETRY_COUNT/$MAX_RETRIES..." + rm -f FastDeploy.tar.gz + sleep 2 + fi + done + + if [ ! -f FastDeploy.tar.gz ] || [ ! -s FastDeploy.tar.gz ]; then + echo "ERROR: Failed to download FastDeploy.tar.gz after $MAX_RETRIES attempts" + exit 1 + fi + + # Verify tar.gz integrity before extraction + if ! tar -tzf FastDeploy.tar.gz > /dev/null 2>&1; then + echo "ERROR: FastDeploy.tar.gz is corrupted or incomplete" + exit 1 + fi + + tar -xf FastDeploy.tar.gz + rm -rf FastDeploy.tar.gz + cd FastDeploy + git config --global user.name "FastDeployCI" + git config --global user.email "fastdeploy_ci@example.com" + git log -n 3 --oneline - name: Run FastDeploy Base Tests shell: bash diff --git a/tests/entrypoints/openai/test_api_server.py b/tests/entrypoints/openai/test_api_server.py new file mode 100644 index 000000000..40d8da93f --- /dev/null +++ b/tests/entrypoints/openai/test_api_server.py @@ -0,0 +1,711 @@ +import asyncio +import importlib +import sys +import types +from contextlib import ExitStack +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +# Shared test fixtures +from fastdeploy.entrypoints.openai.protocol import ( + ChatCompletionResponse, + CompletionResponse, + ErrorInfo, + ErrorResponse, + ModelInfo, + ModelList, + UsageInfo, +) + + +class 
DummyErrorInfo: + def __init__(self, message: str, code=None, **_): + self.message = message + self.code = str(code) if code is not None else code + + +class DummyErrorResponse: + def __init__(self, error): + self.error = error + + def model_dump(self): + return {"error": {"message": self.error.message, "code": self.error.code}} + + +def _build_args(**overrides): + """Return a SimpleNamespace with all attributes accessed at import time.""" + base = dict( + workers=1, + model="test-model", + revision=None, + chat_template=None, + tool_parser_plugin=None, + host="0.0.0.0", + port=9000, + metrics_port=None, + controller_port=-1, + max_concurrency=4, + max_model_len=1024, + max_waiting_time=-1, + max_logprobs=0, + tensor_parallel_size=1, + data_parallel_size=1, + max_num_seqs=8, + api_key=None, + tokenizer=None, + served_model_name=None, + ips=None, + enable_mm_output=False, + tokenizer_base_url=None, + dynamic_load_weight=False, + timeout_graceful_shutdown=0, + timeout=0, + local_data_parallel_id=0, + ) + base.update(overrides) + return SimpleNamespace(**base) + + +def _reload_api_server(args): + """Import/reload api_server with patched dependencies.""" + fake_envs_mod = types.ModuleType("fastdeploy.envs") + + class _FakeEnvVars: + @staticmethod + def get(key, default=None): + return [] if key == "FD_API_KEY" else default + + fake_envs_mod.TRACES_ENABLE = "false" + fake_envs_mod.FD_SERVICE_NAME = "" + fake_envs_mod.FD_HOST_NAME = "" + fake_envs_mod.TRACES_EXPORTER = "console" + fake_envs_mod.EXPORTER_OTLP_ENDPOINT = "" + fake_envs_mod.EXPORTER_OTLP_HEADERS = "" + fake_envs_mod.environment_variables = _FakeEnvVars() + + # Save original sys.argv and replace with minimal valid args to avoid parse errors + original_argv = sys.argv[:] + sys.argv = ["api_server.py", "--model", "test-model", "--port", "9000"] + + try: + with ( + patch("fastdeploy.utils.FlexibleArgumentParser.parse_args", return_value=args), + patch("fastdeploy.utils.retrive_model_from_server", 
return_value=args.model), + patch("fastdeploy.entrypoints.chat_utils.load_chat_template", return_value=None), + patch.dict("sys.modules", {"fastdeploy.envs": fake_envs_mod}), + patch("fastdeploy.envs", fake_envs_mod), + ): + from fastdeploy.entrypoints.openai import api_server as api_server_mod + + return importlib.reload(api_server_mod) + finally: + sys.argv = original_argv + + +def _dummy_engine_args(config_parallel_id=0): + cfg = SimpleNamespace(parallel_config=SimpleNamespace(local_data_parallel_id=config_parallel_id)) + + class DummyArgs: + def create_engine_config(self, port_availability_check=True): + return cfg + + return DummyArgs() + + +def _dummy_engine_client(): + class DummyConnMgr: + async def initialize(self): + pass + + async def close(self): + pass + + class DummyClient: + def __init__(self, *_, **__): + self.connection_manager = DummyConnMgr() + self.zmq_client = SimpleNamespace(close=lambda: None) + self.data_processor = "dp" + self.pid = None + + def create_zmq_client(self, *_, **__): + self.zmq_client = SimpleNamespace(close=lambda: None) + + def check_health(self): + return True, "ok" + + def is_workers_alive(self): + return True, "ok" + + async def rearrange_experts(self, request_dict): + return {"data": request_dict}, 201 + + async def get_per_expert_tokens_stats(self, request_dict): + return {"stats": request_dict}, 202 + + async def check_redundant(self, request_dict): + return {"redundant": request_dict}, 203 + + return DummyClient + + +def _fake_handlers(): + class Handler: + def __init__(self, *_, **__): + pass + + async def create_chat_completion(self, *args, **kwargs): + return args[0] if args else None + + async def create_completion(self, *args, **kwargs): + return args[0] if args else None + + async def create_embedding(self, req): + return SimpleNamespace(model_dump=lambda: {"emb": True}) + + async def create_reward(self, req): + return SimpleNamespace(model_dump=lambda: {"reward": True}) + + async def list_models(self): + return 
SimpleNamespace(model_dump=lambda: {"list": True}) + + return Handler + + +def _patch_common_imports(args, engine_client_cls=None, handler_cls=None): + engine_client_cls = engine_client_cls or _dummy_engine_client() + handler_cls = handler_cls or _fake_handlers() + + fake_paddle = types.ModuleType("paddle") + fake_prom = types.ModuleType("prometheus_client") + fake_prom.multiprocess = SimpleNamespace(mark_process_dead=lambda *_: None) + fake_metrics = types.ModuleType("fastdeploy.metrics.metrics") + fake_metrics.get_filtered_metrics = lambda: "" + fake_metrics_pkg = types.ModuleType("fastdeploy.metrics") + fake_metrics_pkg.metrics = fake_metrics + fake_zmq = types.ModuleType("zmq") + fake_zmq.PUSH = "PUSH" + + stack = ExitStack() + stack.enter_context(patch.dict("sys.modules", {"paddle": fake_paddle})) + stack.enter_context(patch.dict("sys.modules", {"prometheus_client": fake_prom})) + stack.enter_context(patch.dict("sys.modules", {"fastdeploy.metrics": fake_metrics_pkg})) + stack.enter_context(patch.dict("sys.modules", {"fastdeploy.metrics.metrics": fake_metrics})) + stack.enter_context(patch.dict("sys.modules", {"zmq": fake_zmq})) + stack.enter_context( + patch("fastdeploy.entrypoints.openai.api_server.EngineArgs.from_cli_args", return_value=_dummy_engine_args()) + ) + stack.enter_context(patch("fastdeploy.entrypoints.openai.api_server.EngineClient", engine_client_cls)) + stack.enter_context(patch("fastdeploy.entrypoints.openai.api_server.OpenAIServingModels", handler_cls)) + stack.enter_context(patch("fastdeploy.entrypoints.openai.api_server.OpenAIServingChat", handler_cls)) + stack.enter_context(patch("fastdeploy.entrypoints.openai.api_server.OpenAIServingCompletion", handler_cls)) + stack.enter_context(patch("fastdeploy.entrypoints.openai.api_server.OpenAIServingEmbedding", handler_cls)) + stack.enter_context(patch("fastdeploy.entrypoints.openai.api_server.OpenAIServingReward", handler_cls)) + 
stack.enter_context(patch("fastdeploy.entrypoints.openai.api_server.ToolParserManager.import_tool_parser")) + return stack + + +def test_tool_parser_and_load_engine_branches(): + args = _build_args(tool_parser_plugin="plugin") + with ( + patch("fastdeploy.utils.FlexibleArgumentParser.parse_args", return_value=args), + patch("fastdeploy.utils.retrive_model_from_server", return_value=args.model), + patch("fastdeploy.entrypoints.chat_utils.load_chat_template", return_value=None), + patch("fastdeploy.entrypoints.openai.api_server.ToolParserManager.import_tool_parser") as import_mock, + patch("fastdeploy.entrypoints.openai.api_server.LLMEngine.from_engine_args") as llm_from_args, + patch("fastdeploy.entrypoints.openai.api_server.EngineArgs.from_cli_args", return_value=_dummy_engine_args()), + ): + from fastdeploy.entrypoints.openai import api_server as api_server_mod + + api_server = importlib.reload(api_server_mod) + import_mock.assert_called_once() + + api_server.llm_engine = "cached" + assert api_server.load_engine() == "cached" + + api_server.llm_engine = None + llm_from_args.return_value = SimpleNamespace(start=MagicMock(return_value=False)) + assert api_server.load_engine() is None + + with patch.object(api_server_mod.BaseApplication, "__init__", return_value=None): + app_instance = api_server_mod.StandaloneApplication("app", {"bind": "0.0.0.0:1", "unused": None}) + app_instance.cfg = SimpleNamespace(settings={"bind": True}) + app_instance.cfg.set = MagicMock() + app_instance.load_config() + app_instance.cfg.set.assert_called_once() + assert app_instance.load() == "app" + + +def test_load_data_service_branches(): + args = _build_args() + api_server = _reload_api_server(args) + cfg = SimpleNamespace(parallel_config=SimpleNamespace(local_data_parallel_id=1)) + engine_args = SimpleNamespace(create_engine_config=lambda: cfg) + expert = MagicMock() + expert.start.side_effect = [False, True] + + with ( + 
patch("fastdeploy.entrypoints.openai.api_server.EngineArgs.from_cli_args", return_value=engine_args), + patch("fastdeploy.entrypoints.openai.api_server.ExpertService", return_value=expert), + ): + api_server.llm_engine = None + assert api_server.load_data_service() is None + api_server.llm_engine = None + assert api_server.load_data_service() is expert + assert api_server.load_data_service() is expert + + +@pytest.mark.asyncio +async def test_connection_manager_timeout_branch(): + args = _build_args() + api_server = _reload_api_server(args) + + class SlowSemaphore: + async def acquire(self): + await asyncio.sleep(0.01) + + def status(self): + return "busy" + + with patch("fastdeploy.entrypoints.openai.api_server.connection_semaphore", SlowSemaphore()): + with pytest.raises(api_server.HTTPException) as exc: + async with api_server.connection_manager(): + pass + assert exc.value.status_code == 429 + + +def test_health_and_routes(): + args = _build_args() + api_server = _reload_api_server(args) + engine_client = MagicMock() + engine_client.check_health.return_value = (True, "ok") + engine_client.is_workers_alive.return_value = (False, "dead") + api_server.app.state.engine_client = engine_client + + assert api_server.health(MagicMock()).status_code == 304 + assert api_server.ping(MagicMock()).status_code == 304 + + engine_client.is_workers_alive.return_value = (True, "ok") + assert api_server.health(MagicMock()).status_code == 200 + + routes = asyncio.run(api_server.list_all_routes()) + assert isinstance(routes, dict) and routes["routes"] + + +@pytest.mark.asyncio +async def test_wrap_streaming_generator(): + args = _build_args() + api_server = _reload_api_server(args) + sem = MagicMock() + + # Error path with span + span = MagicMock() + span.is_recording.return_value = True + with ( + patch("opentelemetry.trace.get_current_span", return_value=span), + patch("fastdeploy.entrypoints.openai.api_server.connection_semaphore", sem), + ): + + async def gen(): + yield "first" 
+ raise RuntimeError("boom") + + wrapped = api_server.wrap_streaming_generator(gen()) + with pytest.raises(RuntimeError): + async for _ in wrapped(): + pass + span.record_exception.assert_called() + sem.release.assert_called_once() + + # Success path without span + api_server.connection_semaphore = SimpleNamespace(status=lambda: "ok", release=MagicMock()) + with patch("fastdeploy.entrypoints.openai.api_server.trace.get_current_span", return_value=None): + + async def gen2(): + yield "a" + yield "b" + + wrapped = api_server.wrap_streaming_generator(gen2()) + out = [] + async for item in wrapped(): + out.append(item) + assert out == ["a", "b"] + api_server.connection_semaphore.release.assert_called_once() + + # Success path with span and last_chunk event (count > 0) + span = MagicMock() + span.is_recording.return_value = True + api_server.connection_semaphore = SimpleNamespace(status=lambda: "ok", release=MagicMock()) + with patch("opentelemetry.trace.get_current_span", return_value=span): + + async def gen3(): + yield "chunk1" + yield "chunk2" + + wrapped = api_server.wrap_streaming_generator(gen3()) + out = [] + async for item in wrapped(): + out.append(item) + assert out == ["chunk1", "chunk2"] + span.add_event.assert_called() + api_server.connection_semaphore.release.assert_called_once() + + +@pytest.mark.asyncio +async def test_chat_and_completion_routes(): + args = _build_args(dynamic_load_weight=True) + api_server = _reload_api_server(args) + api_server.app.state.dynamic_load_weight = True + api_server.app.state.engine_client = MagicMock() + api_server.app.state.engine_client.is_workers_alive.return_value = (False, "down") + fake_req = SimpleNamespace(headers={}) + body = SimpleNamespace(model_dump_json=lambda: "{}", stream=False) + + # Unhealthy path + resp = await api_server.create_chat_completion(body, fake_req) + assert resp.status_code == 304 + resp = await api_server.create_completion(body, fake_req) + assert resp.status_code == 304 + + # Healthy path 
with dynamic_load_weight=True (missing branch 383, 419) + api_server.app.state.dynamic_load_weight = True + api_server.app.state.engine_client.is_workers_alive.return_value = (True, "ok") + api_server.connection_semaphore = SimpleNamespace(acquire=AsyncMock(), release=MagicMock(), status=lambda: "ok") + success_resp = ChatCompletionResponse(id="1", model="m", choices=[], usage=UsageInfo()) + api_server.app.state.chat_handler = SimpleNamespace(create_chat_completion=AsyncMock(return_value=success_resp)) + api_server.app.state.completion_handler = SimpleNamespace(create_completion=AsyncMock(return_value=success_resp)) + + class DummyCM: + async def __aenter__(self): + return None + + async def __aexit__(self, exc_type, exc, tb): + return False + + with patch("fastdeploy.entrypoints.openai.api_server.connection_manager", return_value=DummyCM()): + resp = await api_server.create_chat_completion(body, fake_req) + assert resp.status_code == 200 + resp = await api_server.create_completion(body, fake_req) + assert resp.status_code == 200 + + # Healthy paths + api_server.app.state.dynamic_load_weight = False + api_server.connection_semaphore = SimpleNamespace(acquire=AsyncMock(), release=MagicMock(), status=lambda: "ok") + + error_resp = ErrorResponse(error=ErrorInfo(message="err")) + chat_handler = MagicMock() + chat_handler.create_chat_completion = AsyncMock(return_value=error_resp) + api_server.app.state.chat_handler = chat_handler + assert (await api_server.create_chat_completion(body, fake_req)).status_code == 500 + + success_resp = ChatCompletionResponse(id="1", model="m", choices=[], usage=UsageInfo()) + api_server.app.state.chat_handler.create_chat_completion = AsyncMock(return_value=success_resp) + assert (await api_server.create_chat_completion(body, fake_req)).status_code == 200 + + async def stream_gen(): + yield "data" + + api_server.app.state.chat_handler.create_chat_completion = AsyncMock(return_value=stream_gen()) + assert isinstance(await 
api_server.create_chat_completion(body, fake_req), api_server.StreamingResponse) + + # Completion handler + completion_handler = MagicMock() + completion_handler.create_completion = AsyncMock(return_value=error_resp) + api_server.app.state.completion_handler = completion_handler + assert (await api_server.create_completion(body, fake_req)).status_code == 500 + + api_server.app.state.completion_handler.create_completion = AsyncMock(return_value=success_resp) + assert (await api_server.create_completion(body, fake_req)).status_code == 200 + + # HTTPException handling + class RaiseHTTP: + async def __aenter__(self): + raise api_server.HTTPException(status_code=418, detail="teapot") + + async def __aexit__(self, exc_type, exc, tb): + return False + + with patch("fastdeploy.entrypoints.openai.api_server.connection_manager", return_value=RaiseHTTP()): + assert (await api_server.create_chat_completion(body, fake_req)).status_code == 418 + assert (await api_server.create_completion(body, fake_req)).status_code == 418 + + +@pytest.mark.asyncio +async def test_chat_completion_tracing(): + args = _build_args(dynamic_load_weight=False) + api_server = _reload_api_server(args) + api_server.envs.TRACES_ENABLE = "true" + api_server.app.state.dynamic_load_weight = False + + fake_req = SimpleNamespace(headers={"x-request-id": "1"}) + body = SimpleNamespace(model_dump_json=lambda: "{}", stream=False) + + chat_resp = ChatCompletionResponse(id="1", model="m", choices=[], usage=UsageInfo()) + completion_resp = CompletionResponse(id="2", model="m", choices=[], usage=UsageInfo()) + + api_server.app.state.chat_handler = SimpleNamespace(create_chat_completion=AsyncMock(return_value=chat_resp)) + api_server.app.state.completion_handler = SimpleNamespace( + create_completion=AsyncMock(return_value=completion_resp) + ) + api_server.connection_semaphore = SimpleNamespace(acquire=AsyncMock(), release=MagicMock(), status=lambda: "ok") + + class DummyCM: + async def __aenter__(self): + return None 
+ + async def __aexit__(self, exc_type, exc, tb): + return False + + with ( + patch("fastdeploy.entrypoints.openai.api_server.extract", return_value="ctx"), + patch("fastdeploy.entrypoints.openai.api_server.connection_manager", return_value=DummyCM()), + ): + resp_chat = await api_server.create_chat_completion(body, fake_req) + resp_comp = await api_server.create_completion(body, fake_req) + + assert resp_chat.status_code == 200 + assert resp_comp.status_code == 200 + assert getattr(body, "trace_context", None) == "ctx" + + # TRACES_ENABLE=True but req.headers is None/empty (missing branch 379, 415) + api_server.envs.TRACES_ENABLE = "true" + fake_req_no_headers = SimpleNamespace(headers=None) + body2 = SimpleNamespace(model_dump_json=lambda: "{}", stream=False) + api_server.app.state.chat_handler = SimpleNamespace(create_chat_completion=AsyncMock(return_value=chat_resp)) + api_server.app.state.completion_handler = SimpleNamespace( + create_completion=AsyncMock(return_value=completion_resp) + ) + + with patch("fastdeploy.entrypoints.openai.api_server.connection_manager", return_value=DummyCM()): + resp_chat2 = await api_server.create_chat_completion(body2, fake_req_no_headers) + resp_comp2 = await api_server.create_completion(body2, fake_req_no_headers) + + assert resp_chat2.status_code == 200 + assert resp_comp2.status_code == 200 + assert not hasattr(body2, "trace_context") + + +@pytest.mark.asyncio +async def test_reward_embedding_and_weights(): + args = _build_args(dynamic_load_weight=True) + api_server = _reload_api_server(args) + api_server.app.state.dynamic_load_weight = True + api_server.app.state.engine_client = MagicMock() + api_server.app.state.engine_client.is_workers_alive.return_value = (False, "down") + + assert (await api_server.create_reward(SimpleNamespace())).status_code == 304 + assert (await api_server.create_embedding(SimpleNamespace())).status_code == 304 + + api_server.app.state.dynamic_load_weight = False + 
api_server.app.state.reward_handler = MagicMock( + create_reward=AsyncMock(return_value=SimpleNamespace(model_dump=lambda: {"ok": True})) + ) + api_server.app.state.embedding_handler = MagicMock( + create_embedding=AsyncMock(return_value=SimpleNamespace(model_dump=lambda: {"ok": True})) + ) + assert (await api_server.create_reward(SimpleNamespace())).status_code == 200 + assert (await api_server.create_embedding(SimpleNamespace())).status_code == 200 + + # Weight update/clear + api_server.app.state.dynamic_load_weight = True + api_server.app.state.engine_client.update_model_weight.return_value = (False, "fail") + assert api_server.update_model_weight(MagicMock()).status_code == 404 + api_server.app.state.engine_client.update_model_weight.return_value = (True, "ok") + assert api_server.update_model_weight(MagicMock()).status_code == 200 + + api_server.app.state.engine_client.clear_load_weight.return_value = (False, "fail") + assert api_server.clear_load_weight(MagicMock()).status_code == 404 + api_server.app.state.engine_client.clear_load_weight.return_value = (True, "ok") + assert api_server.clear_load_weight(MagicMock()).status_code == 200 + + # Disabled branch + api_server.app.state.dynamic_load_weight = False + assert api_server.update_model_weight(MagicMock()).status_code == 404 + assert api_server.clear_load_weight(MagicMock()).status_code == 404 + + +@pytest.mark.asyncio +async def test_expert_and_stats_routes(): + args = _build_args() + with _patch_common_imports(args, engine_client_cls=_dummy_engine_client()): + api_server = _reload_api_server(args) + + api_server.app.state.engine_client = _dummy_engine_client()() + req = MagicMock() + req.json = AsyncMock(return_value={"a": 1}) + + assert (await api_server.rearrange_experts(req)).status_code == 201 + assert (await api_server.get_per_expert_tokens_stats(req)).status_code == 202 + assert (await api_server.check_redundant(req)).status_code == 203 + + +def test_launchers_and_controller(): + args = 
_build_args() + api_server = _reload_api_server(args) + + with patch("fastdeploy.entrypoints.openai.api_server.is_port_available", return_value=False): + with pytest.raises(Exception): + api_server.launch_api_server() + + with ( + patch("fastdeploy.entrypoints.openai.api_server.is_port_available", return_value=True), + patch("fastdeploy.entrypoints.openai.api_server.StandaloneApplication.run", side_effect=RuntimeError("fail")), + ): + api_server.launch_api_server() + + with patch("fastdeploy.entrypoints.openai.api_server.uvicorn.run") as uv_run: + api_server.run_metrics_server() + api_server.run_controller_server() + assert uv_run.call_count == 2 + + with ( + patch("fastdeploy.entrypoints.openai.api_server.is_port_available", return_value=True), + patch("fastdeploy.entrypoints.openai.api_server.run_metrics_server"), + ): + api_server.args.metrics_port = api_server.args.port + 1 + api_server.launch_metrics_server() + + with patch("fastdeploy.entrypoints.openai.api_server.is_port_available", return_value=False): + api_server.args.metrics_port = api_server.args.port + 2 + with pytest.raises(Exception): + api_server.launch_metrics_server() + + api_server.args.controller_port = -1 + api_server.launch_controller_server() + + api_server.args.controller_port = api_server.args.port + 5 + with patch("fastdeploy.entrypoints.openai.api_server.is_port_available", return_value=False): + with pytest.raises(Exception): + api_server.launch_controller_server() + + with ( + patch("fastdeploy.entrypoints.openai.api_server.is_port_available", return_value=True), + patch("fastdeploy.entrypoints.openai.api_server.run_controller_server"), + ): + api_server.launch_controller_server() + + +def test_worker_monitor_and_main(): + args = _build_args() + api_server = _reload_api_server(args) + + api_server.llm_engine = SimpleNamespace(worker_proc=SimpleNamespace(poll=lambda: 1, returncode=9)) + with patch("os.kill") as kill_mock: + api_server.launch_worker_monitor() + kill_mock.assert_called() + 
+ api_server.args.local_data_parallel_id = 0 + with patch("fastdeploy.entrypoints.openai.api_server.load_engine", return_value=False): + api_server.main() + + api_server.args.local_data_parallel_id = 1 + with patch("fastdeploy.entrypoints.openai.api_server.load_data_service", return_value=False): + api_server.main() + + api_server.args.local_data_parallel_id = 0 + with ( + patch("fastdeploy.entrypoints.openai.api_server.load_engine", return_value=True), + patch("fastdeploy.entrypoints.openai.api_server.launch_metrics_server"), + patch("fastdeploy.entrypoints.openai.api_server.launch_worker_monitor"), + patch("fastdeploy.entrypoints.openai.api_server.launch_controller_server"), + patch("fastdeploy.entrypoints.openai.api_server.launch_api_server"), + ): + api_server.main() + + +@pytest.mark.asyncio +async def test_lifespan_and_health(): + args = _build_args() + with _patch_common_imports(args): + api_server = _reload_api_server(args) + engine_client = MagicMock() + engine_client.check_health.return_value = (False, "bad") + api_server.app.state.engine_client = engine_client + + assert api_server.health(MagicMock()).status_code == 404 + routes = await api_server.list_all_routes() + assert isinstance(routes, dict) + + +@pytest.mark.asyncio +async def test_list_models(): + args = _build_args() + with _patch_common_imports(args): + api_server = _reload_api_server(args) + api_server.app.state.dynamic_load_weight = False + + class FakeErrorResponse: + def model_dump(self): + return {"err": True} + + api_server.ErrorResponse = FakeErrorResponse + api_server.app.state.model_handler = MagicMock(list_models=AsyncMock(return_value=FakeErrorResponse())) + resp = await api_server.list_models() + assert resp.status_code == 200 + assert resp.body + + # dynamic_load_weight=True but workers_alive returns True (missing branch 442) + api_server.app.state.dynamic_load_weight = True + api_server.app.state.engine_client = MagicMock() + 
api_server.app.state.engine_client.is_workers_alive.return_value = (True, "ok") + + # Return ModelList instead of ErrorResponse (missing branch 449) + model_list = ModelList(data=[ModelInfo(id="test-model", object="model")]) + api_server.app.state.model_handler.list_models = AsyncMock(return_value=model_list) + resp2 = await api_server.list_models() + assert resp2.status_code == 200 + assert "data" in resp2.body.decode() if hasattr(resp2.body, "decode") else True + + +def test_control_scheduler(): + args = _build_args() + with _patch_common_imports(args): + api_server = _reload_api_server(args) + + with ( + patch("fastdeploy.entrypoints.openai.api_server.ErrorInfo", DummyErrorInfo), + patch("fastdeploy.entrypoints.openai.api_server.ErrorResponse", DummyErrorResponse), + ): + # Engine not loaded + api_server.llm_engine = None + req = SimpleNamespace(reset=False, load_shards_num=None, reallocate_shard=False) + assert api_server.control_scheduler(req).status_code == 500 + + # Without update_config + sched = SimpleNamespace() + api_server.llm_engine = SimpleNamespace(engine=SimpleNamespace(clear_data=MagicMock(), scheduler=sched)) + req = SimpleNamespace(reset=False, load_shards_num=1, reallocate_shard=True) + assert api_server.control_scheduler(req).status_code == 400 + + # Success path + scheduler = SimpleNamespace(update_config=MagicMock(), reset=MagicMock()) + engine = SimpleNamespace(clear_data=MagicMock(), scheduler=scheduler) + api_server.llm_engine = SimpleNamespace(engine=engine) + req = SimpleNamespace(reset=True, load_shards_num=2, reallocate_shard=True) + resp = api_server.control_scheduler(req) + + assert resp.status_code == 200 + engine.clear_data.assert_called_once() + scheduler.reset.assert_called_once() + scheduler.update_config.assert_called_once() + + # Only reset, no update_config (missing branch 681) + scheduler2 = SimpleNamespace(update_config=MagicMock(), reset=MagicMock()) + engine2 = SimpleNamespace(clear_data=MagicMock(), scheduler=scheduler2) 
+ api_server.llm_engine = SimpleNamespace(engine=engine2) + req2 = SimpleNamespace(reset=True, load_shards_num=None, reallocate_shard=False) + resp2 = api_server.control_scheduler(req2) + + assert resp2.status_code == 200 + engine2.clear_data.assert_called_once() + scheduler2.reset.assert_called_once() + scheduler2.update_config.assert_not_called() + + +def test_config_info(): + args = _build_args() + with _patch_common_imports(args): + api_server = _reload_api_server(args) + api_server.llm_engine = None + assert api_server.config_info().status_code == 500