[Logprobs] Support prompt_logprobs and max_logprobs (#4897)

* add prompt logprobs

* trigger ci

* fix unit test

* Update fastdeploy/config.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fastdeploy/entrypoints/llm.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update fastdeploy/engine/sampling_params.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update tests/engine/test_sampling_params.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update tests/engine/test_sampling_params.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix max_logprobs

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Authored by qwes5s5 on 2025-11-12 19:29:48 +08:00, committed by GitHub
parent da7863ae85
commit a2d06118e1
9 changed files with 623 additions and 9 deletions
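
Not part of the diff: a minimal usage sketch of the feature added by this PR, assuming the offline LLM entrypoint touched below. The model path is hypothetical, and the keyword names follow the code shown in this commit.

    from fastdeploy.entrypoints.llm import LLM
    from fastdeploy.engine.sampling_params import SamplingParams

    # Hypothetical local model path; substitute any model FastDeploy can load.
    llm = LLM(model="./your-model-path")

    # logprobs=5: keep the top-5 logprobs for each generated token.
    # prompt_logprobs=3: also return the top-3 logprobs for each prompt token.
    # Both accept -1 as "full vocabulary", bounded by ModelConfig.max_logprobs.
    params = SamplingParams(max_tokens=16, logprobs=5, prompt_logprobs=3)

    outputs = llm.generate(["Hello, my name is"], sampling_params=params)
    print(outputs[0].outputs.logprobs)  # list of dict[token_id, Logprob], one per generated token
    print(outputs[0].prompt_logprobs)   # list of dict[token_id, Logprob], one per prompt token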

View File

@@ -227,8 +227,8 @@ class ModelConfig:
self.think_end_id = args.get("think_end_id", -1)
self.im_patch_id = args.get("image_patch_id", -1)
self.line_break_id = args.get("line_break_id", -1)
if self.max_logprobs == -1 and hasattr(self, "vocab_size"):
self.max_logprobs = self.vocab_size
if self.max_logprobs < -1:
raise ValueError(" The possible values for max_logprobs can't be less than -1 ")
self._post_init()
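
The resolution rule above, restated as a standalone sketch (the helper name and vocab size are illustrative, not part of ModelConfig):

    def resolve_max_logprobs(max_logprobs: int, vocab_size: int) -> int:
        # -1 is a sentinel meaning "up to the full vocabulary".
        if max_logprobs == -1:
            return vocab_size
        if max_logprobs < -1:
            raise ValueError("max_logprobs cannot be less than -1.")
        return max_logprobs

    assert resolve_max_logprobs(-1, 32000) == 32000  # expands to the vocab size
    assert resolve_max_logprobs(20, 32000) == 20     # explicit caps pass through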

View File

@@ -29,7 +29,12 @@ from fastdeploy.engine.pooling_params import PoolingParams
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.openai.protocol import ToolCall
from fastdeploy.utils import data_processor_logger
from fastdeploy.worker.output import LogprobsLists, SampleLogprobs
from fastdeploy.worker.output import (
LogprobsLists,
LogprobsTensors,
PromptLogprobs,
SampleLogprobs,
)
class RequestStatus(Enum):
@@ -463,6 +468,8 @@ class RequestOutput:
request_id: str,
prompt: Optional[str] = None,
prompt_token_ids: Optional[list[int]] = None,
prompt_logprobs: Optional[PromptLogprobs] = None,
prompt_logprobs_tensors: Optional[LogprobsTensors] = None,
output_type: Optional[int] = 3,
outputs: CompletionOutput = None,
finished: bool = False,
@@ -476,6 +483,8 @@ class RequestOutput:
self.request_id = request_id
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.prompt_logprobs = prompt_logprobs
self.prompt_logprobs_tensors = prompt_logprobs_tensors
self.output_type = output_type
self.outputs = outputs
self.finished = finished
@@ -521,6 +530,7 @@ class RequestOutput:
f"RequestOutput(request_id={self.request_id}, "
f"prompt={self.prompt!r}, "
f"prompt_token_ids={self.prompt_token_ids}, "
f"prompt_logprobs={self.prompt_logprobs}, "
f"output_type={self.output_type}, "
f"outputs={self.outputs}, "
f"finished={self.finished}, "
@@ -546,6 +556,7 @@ class RequestOutput:
"request_id": self.request_id,
"prompt": self.prompt,
"prompt_token_ids": self.prompt_token_ids,
"prompt_logprobs": self.prompt_logprobs,
"output_type": self.output_type,
"outputs": None if self.outputs is None else self.outputs.to_dict(),
"metrics": None if self.metrics is None else self.metrics.to_dict(),

View File

@@ -16,6 +16,7 @@
from __future__ import annotations
import os
import random
from dataclasses import dataclass, fields
from enum import Enum
@@ -206,10 +207,12 @@ class SamplingParams:
raise ValueError(
f"min_tokens must be less than or equal to " f"max_tokens={self.max_tokens}, got {self.min_tokens}."
)
if self.logprobs is not None and self.logprobs < 0:
raise ValueError(f"logprobs must be non-negative, got {self.logprobs}.")
if self.logprobs is not None and self.logprobs > 20:
if self.logprobs is not None and self.logprobs < -1:
raise ValueError(f"logprobs must be greater than -1, got {self.logprobs}.")
if self.logprobs is not None and self.logprobs > 20 and os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0") == "0":
raise ValueError("Invalid value for 'top_logprobs': must be less than or equal to 20.")
if self.prompt_logprobs is not None and self.prompt_logprobs < -1:
raise ValueError(f"prompt_logprobs must be greater than or equal to -1, got {self.prompt_logprobs}.")
if not 0 <= self.seed <= 922337203685477580:
raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")

View File

@@ -16,11 +16,13 @@
from __future__ import annotations
import itertools
import logging
import threading
import time
import traceback
import uuid
from collections.abc import Iterable
from typing import Any, Optional, Union
from pydantic import ValidationError
@@ -37,13 +39,20 @@ from fastdeploy.utils import (
llm_logger,
retrive_model_from_server,
)
from fastdeploy.worker.output import Logprob, LogprobsLists
from fastdeploy.worker.output import (
Logprob,
LogprobsLists,
LogprobsTensors,
PromptLogprobs,
)
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
if isinstance(handler, logging.StreamHandler):
root_logger.removeHandler(handler)
NONES = itertools.repeat(None)
class LLM:
"""
@@ -189,12 +198,17 @@ class LLM:
req_ids = self._add_request(prompts=prompts, sampling_params=sampling_params)
topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs
num_prompt_logprobs = (
sampling_params[0].prompt_logprobs if sampling_params_len > 1 else sampling_params.prompt_logprobs
)
# get output
if stream:
return self._run_engine_stream(req_ids, prompts, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
else:
outputs = self._run_engine(req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
outputs = self._run_engine(
req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs, num_prompt_logprobs=num_prompt_logprobs
)
for i in range(len(outputs)):
outputs[i].prompt = prompts[i]
return outputs
@@ -321,6 +335,27 @@ class LLM:
current_sampling_params = sampling_params[i]
else:
current_sampling_params = sampling_params
if kwargs.get("stream") and current_sampling_params.prompt_logprobs is not None:
raise ValueError("prompt_logprobs is not supported with streaming.")
max_logprobs = self.llm_engine.cfg.model_config.max_logprobs
if max_logprobs == -1:
max_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
if current_sampling_params.logprobs is not None:
num_logprobs = current_sampling_params.logprobs
if num_logprobs == -1:
num_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
if num_logprobs > max_logprobs:
raise ValueError(
f"Number of logprobs requested ({num_logprobs}) exceeds maximum allowed value ({max_logprobs})."
)
if current_sampling_params.prompt_logprobs is not None:
num_prompt_logprobs = current_sampling_params.prompt_logprobs
if num_prompt_logprobs == -1:
num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
if num_prompt_logprobs > max_logprobs:
raise ValueError(
f"Number of logprobs requested ({num_prompt_logprobs}) exceeds maximum allowed value ({max_logprobs})."
)
if current_sampling_params.guided_decoding is not None:
guided_decoding_dict = current_sampling_params.guided_decoding.to_dict()
tasks.update(guided_decoding_dict)
@@ -377,7 +412,93 @@ class LLM:
except Exception as e:
llm_logger.error(f"Error building sample logprobs from LogprobsLists: {e}, {str(traceback.format_exc())}")
def _run_engine(self, req_ids: list[str], use_tqdm: bool, topk_logprobs: Optional[int] = None):
def _build_prompt_logprobs(
self,
prompt_logprobs_tensors: LogprobsTensors,
num_prompt_logprobs: int,
):
"""Update with prompt logprobs from worker.
Args:
prompt_logprobs_tensors: tuple containing the prompt logprobs
tensors.
"""
token_ids, logprobs, ranks = prompt_logprobs_tensors
# Detokenize non-incrementally.
# Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
decoded_tokens = [self._decode_token(token_id) for token_id in token_ids.flatten().tolist()]
# Recover shapes.
num_prompt_tokens, num_logprobs = logprobs.shape
# Pythonize the paddle tensors.
prompt_token_ranks = ranks.tolist()
prompt_logprobs = logprobs.tolist()
token_ids = token_ids.tolist()
result: Optional[PromptLogprobs] = []
# Make Logprob for each position.
for pos in range(num_prompt_tokens):
# Handle flattening.
offset = pos * num_logprobs
offset_end = offset + num_logprobs
decoded_tokens_for_pos = NONES if decoded_tokens is None else decoded_tokens[offset:offset_end]
# Update with the Logprob dictionary for this pos.
result.append(
self._make_logprob_dict(
prompt_logprobs[pos],
token_ids[pos],
decoded_tokens_for_pos,
prompt_token_ranks[pos],
num_prompt_logprobs,
)
)
return result
@staticmethod
def _make_logprob_dict(
logprobs: list[float],
logprob_token_ids: list[int],
decoded_tokens: Iterable[str | None],
rank: int,
num_logprobs: int,
) -> dict[int, Logprob]:
"""Make a Logprob dictionary for a position.
Args:
logprobs: list of log probabilities
logprob_token_ids: list of top token ids
decoded_tokens: list of decoded top tokens
rank: rank of the sampled token
num_logprobs: number of logprobs requested
by the user (in addition to sampled logprob)
Returns:
dict[token id, Logprob]
"""
if num_logprobs == -1:
num_logprobs = len(logprobs)
# We do not need a special case for the sampled token
# being in the topk, since inserting duplicated data
# into a dictionary twice is the same as doing it once.
topk_ranks = range(1, num_logprobs + 1)
ranks = itertools.chain((rank,), topk_ranks)
return {
token_id: Logprob(
logprob=logprob,
rank=rank,
decoded_token=token,
)
for token_id, logprob, rank, token in zip(logprob_token_ids, logprobs, ranks, decoded_tokens)
}
def _run_engine(
self,
req_ids: list[str],
use_tqdm: bool,
topk_logprobs: Optional[int] = None,
num_prompt_logprobs: Optional[int] = None,
):
"""
Run the engine and return the list of results.
@@ -422,9 +543,17 @@ class LLM:
# filter logprobs
if result.outputs.top_logprobs and topk_logprobs:
if topk_logprobs == -1:
topk_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
result.outputs.logprobs = self._build_sample_logprobs(
result.outputs.top_logprobs, topk_logprobs
)
if result.prompt_logprobs_tensors and num_prompt_logprobs:
if num_prompt_logprobs == -1:
num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
result.prompt_logprobs = self._build_prompt_logprobs(
result.prompt_logprobs_tensors, num_prompt_logprobs
)
output[pos] = result
finished.append(i)
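
The rank bookkeeping in _make_logprob_dict above, reduced to a plain-Python sketch with tuples standing in for Logprob objects:

    import itertools

    def make_logprob_dict(logprobs, token_ids, decoded, sampled_rank, k):
        # The sampled/actual token comes first with its true rank; the k
        # candidates that follow are ranks 1..k by construction. If the sampled
        # token also sits in the top-k, the duplicate insertion is harmless:
        # same key, same data.
        ranks = itertools.chain((sampled_rank,), range(1, k + 1))
        return {
            tid: (lp, r, tok)
            for tid, lp, r, tok in zip(token_ids, logprobs, ranks, decoded)
        }

    print(make_logprob_dict([-0.3, -0.3, -1.2], [42, 42, 7], ["a", "a", "b"], 1, 2))
    # {42: (-0.3, 1, 'a'), 7: (-1.2, 2, 'b')}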

View File

@@ -290,6 +290,19 @@ class TokenProcessor:
finished=False,
metrics=metrics,
)
if self.use_logprobs:
if getattr(stream_data, "logprobs", None) is not None:
try:
logprobs_list: LogprobsLists = stream_data.logprobs.tolists()
result.outputs.logprob = float(logprobs_list.logprobs[0][0])
result.outputs.top_logprobs = logprobs_list
except Exception as e:
llm_logger.warning(f"Failed to parse logprobs from StreamTransferData: {e}")
if getattr(stream_data, "prompt_logprobs", None) is not None:
try:
result.prompt_logprobs_tensors = stream_data.prompt_logprobs
except Exception as e:
llm_logger.warning(f"Failed to parse prompt_logprobs from StreamTransferData: {e}")
if self.tokens_counter[task_id] == 0:
if task.messages is not None:
result.prompt = task.messages

View File

@@ -30,6 +30,7 @@ class Logprob(NamedTuple):
decoded_token: Optional[str] = None
PromptLogprobs = list[dict[int, Logprob] | None]
# [{token_id, logprob}] for tokens sampled from the top-k
SampleLogprobs = list[dict[int, Logprob]]
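
For orientation, an illustrative value matching the PromptLogprobs alias above: one dict per prompt-token position (entries may also be None, as the alias permits); the token ids, ranks, and logprob values are invented.

    from fastdeploy.worker.output import Logprob

    prompt_logprobs_example = [
        {
            127: Logprob(logprob=-0.21, rank=1, decoded_token="Hello"),
            342: Logprob(logprob=-2.05, rank=2, decoded_token="Hi"),
        },
        {
            11: Logprob(logprob=-0.04, rank=1, decoded_token=","),
            13: Logprob(logprob=-3.70, rank=2, decoded_token="."),
        },
    ]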

View File

@@ -0,0 +1,222 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
import unittest
from unittest.mock import patch
from fastdeploy.engine.sampling_params import SamplingParams
class TestSamplingParamsVerification(unittest.TestCase):
"""Test case for SamplingParams _verify_args method"""
def test_logprobs_valid_values(self):
"""Test valid logprobs values"""
# Test None value (should pass)
params = SamplingParams(logprobs=None)
params._verify_args() # Should not raise
# Test -1 value (should pass)
params = SamplingParams(logprobs=-1)
params._verify_args() # Should not raise
# Test 0 value (should pass)
params = SamplingParams(logprobs=0)
params._verify_args() # Should not raise
# Test 20 value (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "0")
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
params = SamplingParams(logprobs=20)
params._verify_args() # Should not raise
def test_logprobs_invalid_less_than_minus_one(self):
"""Test logprobs less than -1 should raise ValueError"""
with self.assertRaises(ValueError) as cm:
params = SamplingParams(logprobs=-2)
params._verify_args()
self.assertIn("logprobs must be greater than -1", str(cm.exception))
self.assertIn("got -2", str(cm.exception))
def test_logprobs_greater_than_20_with_v1_disabled(self):
"""Test logprobs greater than 20 when FD_USE_GET_SAVE_OUTPUT_V1 is disabled"""
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
with self.assertRaises(ValueError) as cm:
params = SamplingParams(logprobs=21)
params._verify_args()
self.assertEqual("Invalid value for 'top_logprobs': must be less than or equal to 20.", str(cm.exception))
def test_logprobs_greater_than_20_with_v1_enabled(self):
"""Test logprobs greater than 20 when FD_USE_GET_SAVE_OUTPUT_V1 is enabled"""
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
# Should not raise when v1 is enabled
params = SamplingParams(logprobs=21)
params._verify_args() # Should not raise
# Test even larger values when v1 is enabled
params = SamplingParams(logprobs=100)
params._verify_args() # Should not raise
def test_prompt_logprobs_valid_values(self):
"""Test valid prompt_logprobs values"""
# Test None value (should pass)
params = SamplingParams(prompt_logprobs=None)
params._verify_args() # Should not raise
# Test -1 value (should pass)
params = SamplingParams(prompt_logprobs=-1)
params._verify_args() # Should not raise
# Test 0 value (should pass)
params = SamplingParams(prompt_logprobs=0)
params._verify_args() # Should not raise
# Test positive values (should pass)
params = SamplingParams(prompt_logprobs=10)
params._verify_args() # Should not raise
def test_prompt_logprobs_invalid_less_than_minus_one(self):
"""Test prompt_logprobs less than -1 should raise ValueError"""
with self.assertRaises(ValueError) as cm:
params = SamplingParams(prompt_logprobs=-2)
params._verify_args()
self.assertIn("prompt_logprobs must be greater than or equal to -1", str(cm.exception))
self.assertIn("got -2", str(cm.exception))
def test_combined_logprobs_and_prompt_logprobs(self):
"""Test both logprobs and prompt_logprobs together"""
# Test valid combination
params = SamplingParams(logprobs=5, prompt_logprobs=3)
params._verify_args() # Should not raise
# Test invalid logprobs with valid prompt_logprobs
with self.assertRaises(ValueError):
params = SamplingParams(logprobs=-2, prompt_logprobs=5)
params._verify_args()
# Test valid logprobs with invalid prompt_logprobs
with self.assertRaises(ValueError):
params = SamplingParams(logprobs=5, prompt_logprobs=-2)
params._verify_args()
def test_logprobs_boundary_values(self):
"""Test boundary values for logprobs"""
# Test just below limit with v1 disabled
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
params = SamplingParams(logprobs=20)
params._verify_args() # Should pass
# Test just above limit with v1 disabled
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
with self.assertRaises(ValueError):
params = SamplingParams(logprobs=21)
params._verify_args()
def test_prompt_logprobs_boundary_values(self):
"""Test boundary values for prompt_logprobs"""
# Test boundary value -1 (should pass)
params = SamplingParams(prompt_logprobs=-1)
params._verify_args() # Should pass
# Test boundary value just below -1 (should fail)
with self.assertRaises(ValueError):
params = SamplingParams(prompt_logprobs=-2)
params._verify_args()
def test_environment_variable_handling(self):
"""Test different environment variable values"""
# Test FD_USE_GET_SAVE_OUTPUT_V1 = "0" (default behavior)
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
with self.assertRaises(ValueError):
params = SamplingParams(logprobs=21)
params._verify_args()
# Test FD_USE_GET_SAVE_OUTPUT_V1 = "1" (relaxed behavior)
with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
params = SamplingParams(logprobs=21)
params._verify_args() # Should pass
# Test FD_USE_GET_SAVE_OUTPUT_V1 not set (default to "0")
if "FD_USE_GET_SAVE_OUTPUT_V1" in os.environ:
original_value = os.environ["FD_USE_GET_SAVE_OUTPUT_V1"]
del os.environ["FD_USE_GET_SAVE_OUTPUT_V1"]
else:
original_value = None
try:
with self.assertRaises(ValueError):
params = SamplingParams(logprobs=21)
params._verify_args()
finally:
if original_value is not None:
os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = original_value
def test_error_message_formatting(self):
"""Test that error messages are properly formatted"""
# Test logprobs error message
with self.assertRaises(ValueError) as cm:
params = SamplingParams(logprobs=-5)
params._verify_args()
error_msg = str(cm.exception)
self.assertIn("logprobs must be greater than -1", error_msg)
self.assertIn("got -5", error_msg)
# Test prompt_logprobs error message
with self.assertRaises(ValueError) as cm:
params = SamplingParams(prompt_logprobs=-10)
params._verify_args()
error_msg = str(cm.exception)
self.assertIn("prompt_logprobs must be greater than or equal to -1", error_msg)
self.assertIn("got -10", error_msg)
def test_post_init_calls_verify_args(self):
"""Test that __post_init__ calls _verify_args"""
# This should call _verify_args internally
params = SamplingParams(logprobs=5, prompt_logprobs=3)
# The params should be successfully created without errors
self.assertEqual(params.logprobs, 5)
self.assertEqual(params.prompt_logprobs, 3)
# Test that invalid values are caught during initialization
with self.assertRaises(ValueError):
SamplingParams(logprobs=-2)
with self.assertRaises(ValueError):
SamplingParams(prompt_logprobs=-2)
def test_logprobs_with_other_parameters(self):
"""Test logprobs validation with other sampling parameters"""
# Test with temperature
params = SamplingParams(logprobs=5, temperature=0.8)
params._verify_args() # Should pass
# Test with top_p
params = SamplingParams(logprobs=5, top_p=0.9)
params._verify_args() # Should pass
# Test with all parameters
params = SamplingParams(logprobs=5, prompt_logprobs=3, temperature=0.8, top_p=0.9, top_k=50, max_tokens=100)
params._verify_args() # Should pass
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,100 @@
from unittest.mock import MagicMock
import numpy as np
import pytest
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM
from fastdeploy.worker.output import Logprob, LogprobsTensors
class DummyModelConfig:
def __init__(self, max_logprobs=10, ori_vocab_size=50):
self.max_logprobs = max_logprobs
self.ori_vocab_size = ori_vocab_size
@pytest.fixture
def mock_llm():
llm = LLM.__new__(LLM)
llm.llm_engine = MagicMock()
llm.llm_engine.add_requests = MagicMock()
llm.llm_engine.cfg.model_config = DummyModelConfig(max_logprobs=10, ori_vocab_size=100)
# Mock the data_processor.process_logprob_response method to return proper strings
llm.llm_engine.data_processor = MagicMock()
llm.llm_engine.data_processor.process_logprob_response.side_effect = lambda ids, **kwargs: f"TOKEN_{ids[0]}"
return llm
def test_prompt_logprobs_not_supported_with_stream(mock_llm):
sampling = SamplingParams(prompt_logprobs=5)
with pytest.raises(ValueError, match="prompt_logprobs is not supported with streaming"):
mock_llm._add_request(["hi"], sampling, stream=True)
def test_num_logprobs_exceeds_max(mock_llm):
sampling = SamplingParams(logprobs=20)
with pytest.raises(ValueError, match="Number of logprobs requested"):
mock_llm._add_request(["hi"], sampling)
def test_num_prompt_logprobs_exceeds_max(mock_llm):
sampling = SamplingParams(prompt_logprobs=20)
with pytest.raises(ValueError, match="Number of logprobs requested"):
mock_llm._add_request(["hi"], sampling)
def test_logprobs_equal_to_minus_one_uses_ori_vocab_size(mock_llm):
sampling = SamplingParams(logprobs=-1)
mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
mock_llm.llm_engine.cfg.model_config.ori_vocab_size = 30
mock_llm._add_request(["hi"], sampling)
mock_llm.llm_engine.add_requests.assert_called_once()
# Get the first argument (tasks) which should be a dict
call_args = mock_llm.llm_engine.add_requests.call_args
tasks = call_args[0][0] # First positional argument
assert isinstance(tasks, dict)
assert "prompt" in tasks
assert "request_id" in tasks
def test_prompt_logprobs_equal_to_minus_one(mock_llm):
sampling = SamplingParams(prompt_logprobs=-1)
mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
mock_llm.llm_engine.cfg.model_config.ori_vocab_size = 25
mock_llm._add_request(["hi"], sampling)
mock_llm.llm_engine.add_requests.assert_called_once()
def test_build_prompt_logprobs_basic(mock_llm):
# Build 2 tokens, each with 3 logprob values
token_ids = np.array([[1, 2, 3], [4, 5, 6]])
logprobs = np.array([[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]])
ranks = np.array([1, 2])
tensors = LogprobsTensors(token_ids, logprobs, ranks)
result = mock_llm._build_prompt_logprobs(tensors, num_prompt_logprobs=2)
# Check the result format
assert isinstance(result, list)
assert len(result) == 2
for pos_dict in result:
assert isinstance(pos_dict, dict)
for logprob_obj in pos_dict.values():
assert isinstance(logprob_obj, Logprob)
assert logprob_obj.decoded_token.startswith("TOKEN_")
def test_build_prompt_logprobs_handles_minus_one(mock_llm):
token_ids = np.array([[7, 8]])
logprobs = np.array([[-0.9, -1.0]])
ranks = np.array([1])
tensors = LogprobsTensors(token_ids, logprobs, ranks)
result = mock_llm._build_prompt_logprobs(tensors, num_prompt_logprobs=-1)
assert isinstance(result, list)
assert len(result) == 1
pos_dict = result[0]
assert 7 in pos_dict
assert pos_dict[7].decoded_token == "TOKEN_7"

View File

@@ -0,0 +1,135 @@
import unittest
from unittest.mock import MagicMock, patch
import numpy as np
from fastdeploy.engine.request import CompletionOutput, RequestOutput
from fastdeploy.output.token_processor import TokenProcessor
from fastdeploy.worker.output import LogprobsLists
class TestTokenProcessorLogprobs(unittest.TestCase):
def setUp(self):
self.cfg = MagicMock()
self.cfg.model_config.enable_logprob = True
self.cfg.speculative_config.method = None
self.cfg.parallel_config.local_data_parallel_id = 0
self.cached_generated_tokens = MagicMock()
self.engine_worker_queue = MagicMock()
self.split_connector = MagicMock()
self.processor = TokenProcessor(
self.cfg, self.cached_generated_tokens, self.engine_worker_queue, self.split_connector
)
# Mock resource manager
self.processor.resource_manager = MagicMock()
self.processor.resource_manager.stop_flags = [False]
# Create a proper task mock with time attributes
self.task_mock = MagicMock()
self.task_mock.request_id = "test_request"
self.task_mock.pooling_params = None
self.task_mock.messages = None
self.task_mock.disaggregate_info = None
self.task_mock.eos_token_ids = [2]
self.task_mock.inference_start_time = 100.0 # Set a float value for time calculation
self.task_mock.arrival_time = 90.0
self.task_mock.preprocess_end_time = 95.0
self.task_mock.preprocess_start_time = 90.0
self.task_mock.schedule_start_time = 95.0
self.processor.resource_manager.tasks_list = [self.task_mock]
# Mock logger
self.processor.llm_logger = MagicMock()
# Mock metrics to avoid prometheus dependency issues
self.processor.main_process_metrics = MagicMock()
self.processor._recycle_resources = MagicMock()
# Mock the _process_per_token method to avoid prometheus issues
self.processor._process_per_token = MagicMock()
self.processor._process_per_token.return_value = RequestOutput(
request_id="test_request",
outputs=CompletionOutput(
index=0,
send_idx=0,
token_ids=[],
draft_token_ids=[],
),
finished=False,
metrics=MagicMock(),
)
def test_process_logprobs_success(self):
"""Test successful logprobs parsing"""
stream_data = MagicMock()
logprobs = MagicMock()
logprobs.tolists.return_value = LogprobsLists(
logprobs=[[0.5]], logprob_token_ids=[[1]], sampled_token_ranks=[0]
)
stream_data.logprobs = logprobs
stream_data.tokens = np.array([1])
stream_data.batch_id = 0
result = self.processor._process_batch_output_use_zmq([stream_data])
self.assertEqual(len(result), 1)
self.processor.llm_logger.warning.assert_not_called()
def test_process_logprobs_failure(self):
"""Test failed logprobs parsing"""
stream_data = MagicMock()
stream_data.logprobs = MagicMock()
stream_data.logprobs.tolists.side_effect = Exception("Test error")
stream_data.tokens = np.array([1])
stream_data.batch_id = 0
with patch.object(self.processor.llm_logger, "warning"):
result = self.processor._process_batch_output_use_zmq([stream_data])
self.assertEqual(len(result), 1)
self.assertIsNone(result[0].outputs.logprob)
def test_process_prompt_logprobs_success(self):
"""Test successful prompt_logprobs parsing"""
stream_data = MagicMock()
stream_data.logprobs = None
stream_data.prompt_logprobs = np.array([0.1, 0.2])
stream_data.tokens = np.array([1])
stream_data.batch_id = 0
result = self.processor._process_batch_output_use_zmq([stream_data])
self.assertEqual(len(result), 1)
self.processor.llm_logger.warning.assert_not_called()
def test_process_prompt_logprobs_failure(self):
"""Test failed prompt_logprobs parsing"""
stream_data = MagicMock()
stream_data.logprobs = None
stream_data.prompt_logprobs = MagicMock()
stream_data.prompt_logprobs.tolist.side_effect = AttributeError("'NoneType' object has no attribute 'tolist'")
stream_data.tokens = np.array([1])
stream_data.batch_id = 0
with patch.object(self.processor.llm_logger, "warning"):
result = self.processor._process_batch_output_use_zmq([stream_data])
self.assertEqual(len(result), 1)
self.assertIsNone(getattr(result[0], "prompt_logprobs_tensors", None))
def test_process_batch_with_stop_flag(self):
"""Test processing when stop flag is True"""
self.processor.resource_manager.stop_flags = [True]
stream_data = MagicMock()
stream_data.batch_id = 0
result = self.processor._process_batch_output_use_zmq([stream_data])
self.assertEqual(len(result), 0)
if __name__ == "__main__":
unittest.main()