[Bugfix] Fix uninitialized decoded_token and add corresponding unit test. (#3195)

This commit is contained in:
SunLei
2025-08-04 19:23:58 +08:00
committed by GitHub
parent 01d7586661
commit 68bc1d12c0
2 changed files with 84 additions and 1 deletions

View File

@@ -289,6 +289,10 @@ class LLM:
self.llm_engine.add_requests(tasks, current_sampling_params, enable_thinking=enable_thinking) self.llm_engine.add_requests(tasks, current_sampling_params, enable_thinking=enable_thinking)
return req_ids return req_ids
def _decode_token(self, token_id: int) -> str:
"""Decodes a single token ID into its string representation."""
return self.llm_engine.data_processor.process_logprob_response([token_id], clean_up_tokenization_spaces=False)
def _build_sample_logprobs(self, logprobs_lists: LogprobsLists, topk_logprobs: int) -> list[dict[int, Logprob]]: def _build_sample_logprobs(self, logprobs_lists: LogprobsLists, topk_logprobs: int) -> list[dict[int, Logprob]]:
""" """
Constructs a list of dictionaries mapping token IDs to Logprob objects, Constructs a list of dictionaries mapping token IDs to Logprob objects,
@@ -322,8 +326,9 @@ class LLM:
sliced_logprobs_lists = logprobs_lists.slice_columns(1, 1 + effective_topk_logprobs) sliced_logprobs_lists = logprobs_lists.slice_columns(1, 1 + effective_topk_logprobs)
result = [] result = []
for token_ids, logprobs in zip(sliced_logprobs_lists.logprob_token_ids, sliced_logprobs_lists.logprobs): for token_ids, logprobs in zip(sliced_logprobs_lists.logprob_token_ids, sliced_logprobs_lists.logprobs):
logprob_dict = { logprob_dict = {
token_id: Logprob(logprob=logprob, rank=i + 1, decoded_token=None) token_id: Logprob(logprob=logprob, rank=i + 1, decoded_token=self._decode_token(token_id))
for i, (token_id, logprob) in enumerate(zip(token_ids, logprobs)) for i, (token_id, logprob) in enumerate(zip(token_ids, logprobs))
} }
result.append(logprob_dict) result.append(logprob_dict)

View File

@@ -0,0 +1,78 @@
import unittest
from unittest.mock import MagicMock, patch
from fastdeploy.entrypoints.llm import LLM
from fastdeploy.worker.output import Logprob, LogprobsLists
def get_patch_path(cls, method="__init__"):
return f"{cls.__module__}.{cls.__qualname__}.{method}"
class TestBuildSampleLogprobs(unittest.TestCase):
def setUp(self):
"""
Set up the test environment by creating an instance of the LLM class using Mock.
"""
patch_llm = get_patch_path(LLM)
with patch(patch_llm, return_value=None):
self.llm = LLM()
# mock d data_processor
self.llm.llm_engine = MagicMock()
self.llm.llm_engine.data_processor.process_logprob_response.side_effect = (
lambda ids, **kwargs: f"token_{ids[0]}"
)
def test_build_sample_logprobs_basic(self):
"""
Test case for building sample logprobs when `topk_logprobs` is valid.
"""
logprob_token_ids = [[100, 101, 102]]
logprobs = [[-0.1, -0.5, -1.0]]
sampled_token_ranks = [0]
logprobs_lists = LogprobsLists(
logprob_token_ids=logprob_token_ids, logprobs=logprobs, sampled_token_ranks=sampled_token_ranks
)
result = self.llm._build_sample_logprobs(logprobs_lists, topk_logprobs=2)
expected = [
{
101: Logprob(logprob=-0.5, rank=1, decoded_token="token_101"),
102: Logprob(logprob=-1.0, rank=2, decoded_token="token_102"),
}
]
self.assertEqual(result, expected)
def test_build_sample_logprobs_empty_input(self):
"""
Test case where `logprob_token_ids` is empty.
"""
logprobs_lists = MagicMock(spec=LogprobsLists)
logprobs_lists.logprob_token_ids = []
result = self.llm._build_sample_logprobs(logprobs_lists, topk_logprobs=2)
self.assertIsNone(result)
def test_build_sample_logprobs_invalid_topk(self):
"""
Test case where `topk` value exceeds length of first element in `logprob_token_ids`.
"""
logprobs_lists = MagicMock(spec=LogprobsLists)
logprobs_lists.logprob_token_ids = [[100]]
result = self.llm._build_sample_logprobs(logprobs_lists, topk_logprobs=2)
self.assertIsNone(result)
def test_decode_token(self):
"""
Test case for decoding a single token ID.
"""
token_id = 123
decoded = self.llm._decode_token(token_id)
self.assertEqual(decoded, "token_123")
if __name__ == "__main__":
unittest.main()