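"""Unit tests for logprobs handling in fastdeploy.entrypoints.llm.LLM.

File: FastDeploy/tests/entrypoints/test_vllm_run_engine.py

The LLM engine is replaced with lightweight dummies so the tests can cover:

* validation in ``LLM._add_request``: logprobs/prompt_logprobs limits, the
  ``FD_USE_GET_SAVE_OUTPUT_V1`` environment switch, streaming and prefix-caching
  restrictions, and dynamic vocabulary-size lookup from the tokenizer;
* construction of per-position prompt logprobs in ``LLM._build_prompt_logprobs``.
"""
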
import os
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM
from fastdeploy.worker.output import Logprob, LogprobsTensors


class DummyModelConfig:
    def __init__(self, max_logprobs=10, ori_vocab_size=50, enable_logprob=True):
        self.max_logprobs = max_logprobs
        self.ori_vocab_size = ori_vocab_size
        self.enable_logprob = enable_logprob


class DummyCacheConfig:
    def __init__(self, enable_prefix_caching=False):
        self.enable_prefix_caching = enable_prefix_caching


class DummyLLMEngineConfig:
    def __init__(self, model_config=None, cache_config=None):
        self.model_config = model_config or DummyModelConfig()
        self.cache_config = cache_config or DummyCacheConfig()


class DummyLLMEngine:
    def __init__(self, model_config=None, cache_config=None):
        self.cfg = DummyLLMEngineConfig(model_config, cache_config)
        self.data_processor = MagicMock()
        # Mock tokenizer with sp_model attribute
        self.data_processor.tokenizer = MagicMock()
        self.data_processor.tokenizer.sp_model = MagicMock()
        self.data_processor.tokenizer.sp_model.__len__ = MagicMock(return_value=100)
        self.data_processor.tokenizer.vocab = MagicMock()
        self.data_processor.tokenizer.vocab.__len__ = MagicMock(return_value=100)
        self.data_processor.process_logprob_response.side_effect = lambda ids, **kwargs: f"TOKEN_{ids[0]}"
        self.add_requests = MagicMock()
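
# Note: these dummies are assumed to mirror only the parts of the real engine that
# LLM._add_request touches in these tests: cfg.model_config and cfg.cache_config for
# limits and feature flags, data_processor.tokenizer for the dynamic vocabulary size,
# and add_requests as the submission hook the assertions inspect.
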
@pytest.fixture
def mock_llm():
    llm = LLM.__new__(LLM)
    llm.llm_engine = DummyLLMEngine()
    return llm


@pytest.fixture
def mock_llm_with_prefix_caching():
    llm = LLM.__new__(LLM)
    llm.llm_engine = DummyLLMEngine(cache_config=DummyCacheConfig(enable_prefix_caching=True))
    return llm


def test_prompt_logprobs_not_supported_with_stream(mock_llm):
    # Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        sampling = SamplingParams(prompt_logprobs=5)
        with pytest.raises(ValueError, match="prompt_logprobs is not supported with streaming"):
            mock_llm._add_request(["hi"], sampling, stream=True)


def test_prompt_logprobs_not_supported_with_prefix_caching(mock_llm_with_prefix_caching):
    # Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        sampling = SamplingParams(prompt_logprobs=5)
        with pytest.raises(ValueError, match="prompt_logprobs is not supported with prefix caching enabled"):
            mock_llm_with_prefix_caching._add_request(["hi"], sampling)


def test_num_logprobs_exceeds_max(mock_llm):
    # Set FD_USE_GET_SAVE_OUTPUT_V1=1 to allow logprobs > 20
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        sampling = SamplingParams(logprobs=20)
        with pytest.raises(ValueError, match="Number of logprobs requested"):
            mock_llm._add_request(["hi"], sampling)


def test_max_logprobs_exceeds_vocab_size(mock_llm):
    # Test case where max_logprobs > ori_vocab_size
    mock_llm.llm_engine.cfg.model_config.max_logprobs = 150  # > vocab size (100)
    with pytest.raises(ValueError, match="max_logprobs \\(150\\) exceeds vocabulary size \\(100\\)"):
        mock_llm._add_request(["hi"], SamplingParams())


def test_max_logprobs_less_than_minus_one(mock_llm):
    # Test case where max_logprobs < -1
    mock_llm.llm_engine.cfg.model_config.max_logprobs = -2
    with pytest.raises(ValueError, match="max_logprobs \\(-2\\) can't be less than -1"):
        mock_llm._add_request(["hi"], SamplingParams())


def test_logprobs_minus_one_uses_vocab_size(mock_llm):
    # Test that logprobs=-1 uses vocab size
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        sampling = SamplingParams(logprobs=-1)
        mock_llm.llm_engine.cfg.model_config.max_logprobs = -1  # Allow unlimited
        mock_llm._add_request(["hi"], sampling)
        mock_llm.llm_engine.add_requests.assert_called_once()


def test_num_prompt_logprobs_exceeds_max(mock_llm):
    # Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        sampling = SamplingParams(prompt_logprobs=20)
        with pytest.raises(ValueError, match="Number of logprobs requested"):
            mock_llm._add_request(["hi"], sampling)


def test_logprobs_equal_to_minus_one_uses_ori_vocab_size(mock_llm):
    # Set FD_USE_GET_SAVE_OUTPUT_V1=1 to allow logprobs=-1
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        sampling = SamplingParams(logprobs=-1)
        mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
        mock_llm._add_request(["hi"], sampling)
        mock_llm.llm_engine.add_requests.assert_called_once()
        # Get the first argument (tasks), which should be a dict
        call_args = mock_llm.llm_engine.add_requests.call_args
        tasks = call_args[0][0]  # First positional argument
        assert isinstance(tasks, dict)
        assert "prompt" in tasks
        assert "request_id" in tasks


def test_prompt_logprobs_equal_to_minus_one(mock_llm):
    # Set FD_USE_GET_SAVE_OUTPUT_V1=1 to enable prompt_logprobs support and allow -1
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        sampling = SamplingParams(prompt_logprobs=-1)
        mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
        mock_llm._add_request(["hi"], sampling)
        mock_llm.llm_engine.add_requests.assert_called_once()


def test_dynamic_vocab_size_from_sp_model(mock_llm):
    # Test that ori_vocab_size is dynamically obtained from sp_model
    mock_llm.llm_engine.data_processor.tokenizer.sp_model.__len__.return_value = 200
    mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        sampling = SamplingParams(logprobs=-1)
        mock_llm._add_request(["hi"], sampling)
        # Should use the dynamic vocab size (200)
        mock_llm.llm_engine.add_requests.assert_called_once()


def test_dynamic_vocab_size_from_vocab_fallback(mock_llm):
    # Test fallback to vocab when sp_model is not available
    del mock_llm.llm_engine.data_processor.tokenizer.sp_model
    mock_llm.llm_engine.data_processor.tokenizer.vocab.__len__.return_value = 300
    mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
    with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
        sampling = SamplingParams(logprobs=-1)
        mock_llm._add_request(["hi"], sampling)
        # Should use the vocab size (300)
        mock_llm.llm_engine.add_requests.assert_called_once()
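
# The remaining tests exercise LLM._build_prompt_logprobs with LogprobsTensors built
# from plain numpy arrays. Judging by the fixtures and assertions below, each row of
# the token-id/logprob arrays holds the top candidates for one prompt position, the
# rank array has one entry per position, and the returned list is one element longer
# than the number of positions (the leading position presumably carries no logprobs,
# hence the `is not None` guard).
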
def test_build_prompt_logprobs_basic(mock_llm):
    # Construct 2 tokens, each with 3 corresponding logprob values
    token_ids = np.array([[1, 2, 3], [4, 5, 6]])
    logprobs = np.array([[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]])
    ranks = np.array([1, 2])
    tensors = LogprobsTensors(token_ids, logprobs, ranks)
    result = mock_llm._build_prompt_logprobs(tensors, num_prompt_logprobs=2)
    # Check the result format
    assert isinstance(result, list)
    assert len(result) == 3
    for pos_dict in result:
        if pos_dict is not None:
            assert isinstance(pos_dict, dict)
            for logprob_obj in pos_dict.values():
                assert isinstance(logprob_obj, Logprob)
                assert logprob_obj.decoded_token.startswith("TOKEN_")


def test_build_prompt_logprobs_handles_minus_one(mock_llm):
    token_ids = np.array([[7, 8]])
    logprobs = np.array([[-0.9, -1.0]])
    ranks = np.array([1])
    tensors = LogprobsTensors(token_ids, logprobs, ranks)
    result = mock_llm._build_prompt_logprobs(tensors, num_prompt_logprobs=-1)
    assert isinstance(result, list)
    assert len(result) == 2
    pos_dict = result[1]
    assert 7 in pos_dict
    assert pos_dict[7].decoded_token == "TOKEN_7"
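
# Sketch of how to run just this module from the repository root (path taken from
# the repository layout noted in the module docstring):
#   python -m pytest tests/entrypoints/test_vllm_run_engine.py -v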