FastDeploy/tests/entrypoints/test_vllm_run_engine.py
qwes5s5 a2d06118e1 [Logprobs] Support prompt_logprobs and max_logprobs (#4897)
* add prompt logprobs

* trigger ci

* fix unit test

* Update fastdeploy/config.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fastdeploy/entrypoints/llm.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fastdeploy/engine/sampling_params.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update tests/engine/test_sampling_params.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update tests/engine/test_sampling_params.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix max_logprobs

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-12 19:29:48 +08:00

101 lines | 3.6 KiB | Python

from unittest.mock import MagicMock

import numpy as np
import pytest

from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM
from fastdeploy.worker.output import Logprob, LogprobsTensors
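

# DummyModelConfig is a minimal stand-in for the engine's real model config;
# it carries only the two attributes these tests touch (max_logprobs and
# ori_vocab_size), which the logprobs validation presumably reads.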
class DummyModelConfig:
    def __init__(self, max_logprobs=10, ori_vocab_size=50):
        self.max_logprobs = max_logprobs
        self.ori_vocab_size = ori_vocab_size


@pytest.fixture
def mock_llm():
    llm = LLM.__new__(LLM)
    llm.llm_engine = MagicMock()
    llm.llm_engine.add_requests = MagicMock()
    llm.llm_engine.cfg.model_config = DummyModelConfig(max_logprobs=10, ori_vocab_size=100)
    # Mock the data_processor.process_logprob_response method to return proper strings
    llm.llm_engine.data_processor = MagicMock()
    llm.llm_engine.data_processor.process_logprob_response.side_effect = lambda ids, **kwargs: f"TOKEN_{ids[0]}"
    return llm
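

# The next three tests cover the request-time validation introduced with
# prompt_logprobs/max_logprobs support: prompt_logprobs is rejected when
# streaming, and both logprobs and prompt_logprobs are capped by
# model_config.max_logprobs.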
def test_prompt_logprobs_not_supported_with_stream(mock_llm):
    sampling = SamplingParams(prompt_logprobs=5)
    with pytest.raises(ValueError, match="prompt_logprobs is not supported with streaming"):
        mock_llm._add_request(["hi"], sampling, stream=True)


def test_num_logprobs_exceeds_max(mock_llm):
    sampling = SamplingParams(logprobs=20)
    with pytest.raises(ValueError, match="Number of logprobs requested"):
        mock_llm._add_request(["hi"], sampling)


def test_num_prompt_logprobs_exceeds_max(mock_llm):
    sampling = SamplingParams(prompt_logprobs=20)
    with pytest.raises(ValueError, match="Number of logprobs requested"):
        mock_llm._add_request(["hi"], sampling)
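

# -1 acts as an "unlimited" sentinel: with max_logprobs = -1 the cap
# presumably falls back to ori_vocab_size (the full original vocabulary),
# which is why the two tests below override both values and expect the
# request to be accepted.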
def test_logprobs_equal_to_minus_one_uses_ori_vocab_size(mock_llm):
    sampling = SamplingParams(logprobs=-1)
    mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
    mock_llm.llm_engine.cfg.model_config.ori_vocab_size = 30
    mock_llm._add_request(["hi"], sampling)
    mock_llm.llm_engine.add_requests.assert_called_once()
    # Get the first argument (tasks) which should be a dict
    call_args = mock_llm.llm_engine.add_requests.call_args
    tasks = call_args[0][0]  # First positional argument
    assert isinstance(tasks, dict)
    assert "prompt" in tasks
    assert "request_id" in tasks


def test_prompt_logprobs_equal_to_minus_one(mock_llm):
    sampling = SamplingParams(prompt_logprobs=-1)
    mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
    mock_llm.llm_engine.cfg.model_config.ori_vocab_size = 25
    mock_llm._add_request(["hi"], sampling)
    mock_llm.llm_engine.add_requests.assert_called_once()
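

# The remaining tests check the structure LLM._build_prompt_logprobs returns:
# one dict per prompt position mapping token id -> Logprob, with decoded
# tokens produced by the mocked data_processor ("TOKEN_<id>").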
def test_build_prompt_logprobs_basic(mock_llm):
    # Build 2 token positions, each with 3 candidate logprob values
    token_ids = np.array([[1, 2, 3], [4, 5, 6]])
    logprobs = np.array([[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]])
    ranks = np.array([1, 2])
    tensors = LogprobsTensors(token_ids, logprobs, ranks)
    result = mock_llm._build_prompt_logprobs(tensors, num_prompt_logprobs=2)
    # Check the result format
    assert isinstance(result, list)
    assert len(result) == 2
    for pos_dict in result:
        assert isinstance(pos_dict, dict)
        for logprob_obj in pos_dict.values():
            assert isinstance(logprob_obj, Logprob)
            assert logprob_obj.decoded_token.startswith("TOKEN_")


def test_build_prompt_logprobs_handles_minus_one(mock_llm):
    token_ids = np.array([[7, 8]])
    logprobs = np.array([[-0.9, -1.0]])
    ranks = np.array([1])
    tensors = LogprobsTensors(token_ids, logprobs, ranks)
    result = mock_llm._build_prompt_logprobs(tensors, num_prompt_logprobs=-1)
    assert isinstance(result, list)
    assert len(result) == 1
    pos_dict = result[0]
    assert 7 in pos_dict
    assert pos_dict[7].decoded_token == "TOKEN_7"
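

# A minimal end-to-end sketch (an illustration, not part of the original
# file) of how the new parameters are meant to be requested, assuming the
# usual LLM entry point; the model path is a placeholder:
#
#     llm = LLM(model="path/to/model")
#     params = SamplingParams(logprobs=5, prompt_logprobs=5)
#     outputs = llm.generate(["hi"], params)
#
# Note that prompt_logprobs is rejected when streaming, per the first test
# above.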