diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 61d125b40..4e31f9213 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -227,8 +227,8 @@ class ModelConfig:
         self.think_end_id = args.get("think_end_id", -1)
         self.im_patch_id = args.get("image_patch_id", -1)
         self.line_break_id = args.get("line_break_id", -1)
-        if self.max_logprobs == -1 and hasattr(self, "vocab_size"):
-            self.max_logprobs = self.vocab_size
+        if self.max_logprobs < -1:
+            raise ValueError(f"max_logprobs must be greater than or equal to -1, got {self.max_logprobs}.")

         self._post_init()

diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py
index 743a35fc5..776611560 100644
--- a/fastdeploy/engine/request.py
+++ b/fastdeploy/engine/request.py
@@ -29,7 +29,12 @@ from fastdeploy.engine.pooling_params import PoolingParams
 from fastdeploy.engine.sampling_params import SamplingParams
 from fastdeploy.entrypoints.openai.protocol import ToolCall
 from fastdeploy.utils import data_processor_logger
-from fastdeploy.worker.output import LogprobsLists, SampleLogprobs
+from fastdeploy.worker.output import (
+    LogprobsLists,
+    LogprobsTensors,
+    PromptLogprobs,
+    SampleLogprobs,
+)


 class RequestStatus(Enum):
@@ -463,6 +468,8 @@ class RequestOutput:
         request_id: str,
         prompt: Optional[str] = None,
         prompt_token_ids: Optional[list[int]] = None,
+        prompt_logprobs: Optional[PromptLogprobs] = None,
+        prompt_logprobs_tensors: Optional[LogprobsTensors] = None,
         output_type: Optional[int] = 3,
         outputs: CompletionOutput = None,
         finished: bool = False,
@@ -476,6 +483,8 @@
         self.request_id = request_id
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
+        self.prompt_logprobs = prompt_logprobs
+        self.prompt_logprobs_tensors = prompt_logprobs_tensors
         self.output_type = output_type
         self.outputs = outputs
         self.finished = finished
@@ -521,6 +530,7 @@
             f"RequestOutput(request_id={self.request_id}, "
             f"prompt={self.prompt!r}, "
             f"prompt_token_ids={self.prompt_token_ids}, "
+            f"prompt_logprobs={self.prompt_logprobs}, "
             f"output_type={self.output_type}, "
             f"outputs={self.outputs}, "
             f"finished={self.finished}, "
@@ -546,6 +556,7 @@
             "request_id": self.request_id,
             "prompt": self.prompt,
             "prompt_token_ids": self.prompt_token_ids,
+            "prompt_logprobs": self.prompt_logprobs,
             "output_type": self.output_type,
             "outputs": None if self.outputs is None else self.outputs.to_dict(),
             "metrics": None if self.metrics is None else self.metrics.to_dict(),
diff --git a/fastdeploy/engine/sampling_params.py b/fastdeploy/engine/sampling_params.py
index 60aa67964..0b0ae5f80 100644
--- a/fastdeploy/engine/sampling_params.py
+++ b/fastdeploy/engine/sampling_params.py
@@ -16,6 +16,7 @@

 from __future__ import annotations

+import os
 import random
 from dataclasses import dataclass, fields
 from enum import Enum
@@ -206,10 +207,12 @@ class SamplingParams:
             raise ValueError(
                 f"min_tokens must be less than or equal to "
                 f"max_tokens={self.max_tokens}, got {self.min_tokens}."
) - if self.logprobs is not None and self.logprobs < 0: - raise ValueError(f"logprobs must be non-negative, got {self.logprobs}.") - if self.logprobs is not None and self.logprobs > 20: + if self.logprobs is not None and self.logprobs < -1: + raise ValueError(f"logprobs must be greater than -1, got {self.logprobs}.") + if self.logprobs is not None and self.logprobs > 20 and os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0") == "0": raise ValueError("Invalid value for 'top_logprobs': must be less than or equal to 20.") + if self.prompt_logprobs is not None and self.prompt_logprobs < -1: + raise ValueError(f"prompt_logprobs must be greater than or equal to -1, got {self.prompt_logprobs}.") if not 0 <= self.seed <= 922337203685477580: raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.") diff --git a/fastdeploy/entrypoints/llm.py b/fastdeploy/entrypoints/llm.py index 430011d23..adf7d1cd7 100644 --- a/fastdeploy/entrypoints/llm.py +++ b/fastdeploy/entrypoints/llm.py @@ -16,11 +16,13 @@ from __future__ import annotations +import itertools import logging import threading import time import traceback import uuid +from collections.abc import Iterable from typing import Any, Optional, Union from pydantic import ValidationError @@ -37,13 +39,20 @@ from fastdeploy.utils import ( llm_logger, retrive_model_from_server, ) -from fastdeploy.worker.output import Logprob, LogprobsLists +from fastdeploy.worker.output import ( + Logprob, + LogprobsLists, + LogprobsTensors, + PromptLogprobs, +) root_logger = logging.getLogger() for handler in root_logger.handlers[:]: if isinstance(handler, logging.StreamHandler): root_logger.removeHandler(handler) +NONES = itertools.repeat(None) + class LLM: """ @@ -189,12 +198,17 @@ class LLM: req_ids = self._add_request(prompts=prompts, sampling_params=sampling_params) topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs + num_prompt_logprobs = ( + sampling_params[0].prompt_logprobs if sampling_params_len > 1 else sampling_params.prompt_logprobs + ) # get output if stream: return self._run_engine_stream(req_ids, prompts, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs) else: - outputs = self._run_engine(req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs) + outputs = self._run_engine( + req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs, num_prompt_logprobs=num_prompt_logprobs + ) for i in range(len(outputs)): outputs[i].prompt = prompts[i] return outputs @@ -321,6 +335,27 @@ class LLM: current_sampling_params = sampling_params[i] else: current_sampling_params = sampling_params + if kwargs.get("stream") and current_sampling_params.prompt_logprobs is not None: + raise ValueError("prompt_logprobs is not supported with streaming.") + max_logprobs = self.llm_engine.cfg.model_config.max_logprobs + if max_logprobs == -1: + max_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size + if current_sampling_params.logprobs is not None: + num_logprobs = current_sampling_params.logprobs + if num_logprobs == -1: + num_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size + if num_logprobs > max_logprobs: + raise ValueError( + f"Number of logprobs requested ({num_logprobs}) exceeds maximum allowed value ({max_logprobs})." 
+ ) + if current_sampling_params.prompt_logprobs is not None: + num_prompt_logprobs = current_sampling_params.prompt_logprobs + if num_prompt_logprobs == -1: + num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size + if num_prompt_logprobs > max_logprobs: + raise ValueError( + f"Number of logprobs requested ({num_prompt_logprobs}) exceeds maximum allowed value ({max_logprobs})." + ) if current_sampling_params.guided_decoding is not None: guided_decoding_dict = current_sampling_params.guided_decoding.to_dict() tasks.update(guided_decoding_dict) @@ -377,7 +412,93 @@ class LLM: except Exception as e: llm_logger.error(f"Error building sample logprobs from LogprobsLists: {e}, {str(traceback.format_exc())}") - def _run_engine(self, req_ids: list[str], use_tqdm: bool, topk_logprobs: Optional[int] = None): + def _build_prompt_logprobs( + self, + prompt_logprobs_tensors: LogprobsTensors, + num_prompt_logprobs: int, + ): + """Update with prompt logprobs from worker. + Args: + prompt_logprobs_tensors: tuple containing the prompt logprobs + tensors. + """ + + token_ids, logprobs, ranks = prompt_logprobs_tensors + + # Detokenize non-incrementally. + # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps] + decoded_tokens = [self._decode_token(token_id) for token_id in token_ids.flatten().tolist()] + + # Recover shapes. + num_prompt_tokens, num_logprobs = logprobs.shape + + # Pythonize the paddle tensors. + prompt_token_ranks = ranks.tolist() + prompt_logprobs = logprobs.tolist() + token_ids = token_ids.tolist() + result: Optional[PromptLogprobs] = [] + # Make Logprob for each position. + for pos in range(num_prompt_tokens): + # Handle flattening. + offset = pos * num_logprobs + offset_end = offset + num_logprobs + decoded_tokens_for_pos = NONES if decoded_tokens is None else decoded_tokens[offset:offset_end] + + # Update with the Logprob dictionary for this pos. + result.append( + self._make_logprob_dict( + prompt_logprobs[pos], + token_ids[pos], + decoded_tokens_for_pos, + prompt_token_ranks[pos], + num_prompt_logprobs, + ) + ) + return result + + @staticmethod + def _make_logprob_dict( + logprobs: list[float], + logprob_token_ids: list[int], + decoded_tokens: Iterable[str | None], + rank: int, + num_logprobs: int, + ) -> dict[int, Logprob]: + """Make a Logprob dictionary for a position. + Args: + logprobs: list of log probabilities + logprob_token_ids: list of top token ids + decoded_tokens: list of decoded top tokens + rank: rank of the sampled token + num_logprobs: number of logprobs requested + by the user (in addition to sampled logprob) + Returns: + dict[token id, Logprob] + """ + if num_logprobs == -1: + num_logprobs = len(logprobs) + # We do not need a special case for the sampled token + # being in the topk, since inserting duplicated data + # into a dictionary twice is the same as doing it once. 
+ topk_ranks = range(1, num_logprobs + 1) + ranks = itertools.chain((rank,), topk_ranks) + + return { + token_id: Logprob( + logprob=logprob, + rank=rank, + decoded_token=token, + ) + for token_id, logprob, rank, token in zip(logprob_token_ids, logprobs, ranks, decoded_tokens) + } + + def _run_engine( + self, + req_ids: list[str], + use_tqdm: bool, + topk_logprobs: Optional[int] = None, + num_prompt_logprobs: Optional[int] = None, + ): """ 运行引擎,并返回结果列表。 @@ -422,9 +543,17 @@ class LLM: # filter logprobs if result.outputs.top_logprobs and topk_logprobs: + if topk_logprobs == -1: + topk_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size result.outputs.logprobs = self._build_sample_logprobs( result.outputs.top_logprobs, topk_logprobs ) + if result.prompt_logprobs_tensors and num_prompt_logprobs: + if num_prompt_logprobs == -1: + num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size + result.prompt_logprobs = self._build_prompt_logprobs( + result.prompt_logprobs_tensors, num_prompt_logprobs + ) output[pos] = result finished.append(i) diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index f7b0f67b9..fa284190c 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -290,6 +290,19 @@ class TokenProcessor: finished=False, metrics=metrics, ) + if self.use_logprobs: + if getattr(stream_data, "logprobs", None) is not None: + try: + logprobs_list: LogprobsLists = stream_data.logprobs.tolists() + result.outputs.logprob = float(logprobs_list.logprobs[0][0]) + result.outputs.top_logprobs = logprobs_list + except Exception as e: + llm_logger.warning(f"Failed to parse logprobs from StreamTransferData: {e}") + if getattr(stream_data, "prompt_logprobs", None) is not None: + try: + result.prompt_logprobs_tensors = stream_data.prompt_logprobs + except Exception as e: + llm_logger.warning(f"Failed to parse prompt_logprobs from StreamTransferData: {e}") if self.tokens_counter[task_id] == 0: if task.messages is not None: result.prompt = task.messages diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py index 45ee9f906..a674c3606 100644 --- a/fastdeploy/worker/output.py +++ b/fastdeploy/worker/output.py @@ -30,6 +30,7 @@ class Logprob(NamedTuple): decoded_token: Optional[str] = None +PromptLogprobs = list[dict[int, Logprob] | None] # [{token_id, logprob}] for tokens sampled from the top-k SampleLogprobs = list[dict[int, Logprob]] diff --git a/tests/engine/test_sampling_params.py b/tests/engine/test_sampling_params.py new file mode 100644 index 000000000..e8210c35e --- /dev/null +++ b/tests/engine/test_sampling_params.py @@ -0,0 +1,222 @@ +""" +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" + +import os +import unittest +from unittest.mock import patch + +from fastdeploy.engine.sampling_params import SamplingParams + + +class TestSamplingParamsVerification(unittest.TestCase): + """Test case for SamplingParams _verify_args method""" + + def test_logprobs_valid_values(self): + """Test valid logprobs values""" + # Test None value (should pass) + params = SamplingParams(logprobs=None) + params._verify_args() # Should not raise + + # Test -1 value (should pass) + params = SamplingParams(logprobs=-1) + params._verify_args() # Should not raise + + # Test 0 value (should pass) + params = SamplingParams(logprobs=0) + params._verify_args() # Should not raise + + # Test 20 value (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "0") + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}): + params = SamplingParams(logprobs=20) + params._verify_args() # Should not raise + + def test_logprobs_invalid_less_than_minus_one(self): + """Test logprobs less than -1 should raise ValueError""" + with self.assertRaises(ValueError) as cm: + params = SamplingParams(logprobs=-2) + params._verify_args() + + self.assertIn("logprobs must be greater than -1", str(cm.exception)) + self.assertIn("got -2", str(cm.exception)) + + def test_logprobs_greater_than_20_with_v1_disabled(self): + """Test logprobs greater than 20 when FD_USE_GET_SAVE_OUTPUT_V1 is disabled""" + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}): + with self.assertRaises(ValueError) as cm: + params = SamplingParams(logprobs=21) + params._verify_args() + + self.assertEqual("Invalid value for 'top_logprobs': must be less than or equal to 20.", str(cm.exception)) + + def test_logprobs_greater_than_20_with_v1_enabled(self): + """Test logprobs greater than 20 when FD_USE_GET_SAVE_OUTPUT_V1 is enabled""" + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + # Should not raise when v1 is enabled + params = SamplingParams(logprobs=21) + params._verify_args() # Should not raise + + # Test even larger values when v1 is enabled + params = SamplingParams(logprobs=100) + params._verify_args() # Should not raise + + def test_prompt_logprobs_valid_values(self): + """Test valid prompt_logprobs values""" + # Test None value (should pass) + params = SamplingParams(prompt_logprobs=None) + params._verify_args() # Should not raise + + # Test -1 value (should pass) + params = SamplingParams(prompt_logprobs=-1) + params._verify_args() # Should not raise + + # Test 0 value (should pass) + params = SamplingParams(prompt_logprobs=0) + params._verify_args() # Should not raise + + # Test positive values (should pass) + params = SamplingParams(prompt_logprobs=10) + params._verify_args() # Should not raise + + def test_prompt_logprobs_invalid_less_than_minus_one(self): + """Test prompt_logprobs less than -1 should raise ValueError""" + with self.assertRaises(ValueError) as cm: + params = SamplingParams(prompt_logprobs=-2) + params._verify_args() + + self.assertIn("prompt_logprobs must be greater than or equal to -1", str(cm.exception)) + self.assertIn("got -2", str(cm.exception)) + + def test_combined_logprobs_and_prompt_logprobs(self): + """Test both logprobs and prompt_logprobs together""" + # Test valid combination + params = SamplingParams(logprobs=5, prompt_logprobs=3) + params._verify_args() # Should not raise + + # Test invalid logprobs with valid prompt_logprobs + with self.assertRaises(ValueError): + params = SamplingParams(logprobs=-2, prompt_logprobs=5) + params._verify_args() + + # Test valid logprobs with invalid 
prompt_logprobs + with self.assertRaises(ValueError): + params = SamplingParams(logprobs=5, prompt_logprobs=-2) + params._verify_args() + + def test_logprobs_boundary_values(self): + """Test boundary values for logprobs""" + # Test just below limit with v1 disabled + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}): + params = SamplingParams(logprobs=20) + params._verify_args() # Should pass + + # Test just above limit with v1 disabled + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}): + with self.assertRaises(ValueError): + params = SamplingParams(logprobs=21) + params._verify_args() + + def test_prompt_logprobs_boundary_values(self): + """Test boundary values for prompt_logprobs""" + # Test boundary value -1 (should pass) + params = SamplingParams(prompt_logprobs=-1) + params._verify_args() # Should pass + + # Test boundary value just below -1 (should fail) + with self.assertRaises(ValueError): + params = SamplingParams(prompt_logprobs=-2) + params._verify_args() + + def test_environment_variable_handling(self): + """Test different environment variable values""" + # Test FD_USE_GET_SAVE_OUTPUT_V1 = "0" (default behavior) + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}): + with self.assertRaises(ValueError): + params = SamplingParams(logprobs=21) + params._verify_args() + + # Test FD_USE_GET_SAVE_OUTPUT_V1 = "1" (relaxed behavior) + with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}): + params = SamplingParams(logprobs=21) + params._verify_args() # Should pass + + # Test FD_USE_GET_SAVE_OUTPUT_V1 not set (default to "0") + if "FD_USE_GET_SAVE_OUTPUT_V1" in os.environ: + original_value = os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] + del os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] + else: + original_value = None + + try: + with self.assertRaises(ValueError): + params = SamplingParams(logprobs=21) + params._verify_args() + finally: + if original_value is not None: + os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = original_value + + def test_error_message_formatting(self): + """Test that error messages are properly formatted""" + # Test logprobs error message + with self.assertRaises(ValueError) as cm: + params = SamplingParams(logprobs=-5) + params._verify_args() + + error_msg = str(cm.exception) + self.assertIn("logprobs must be greater than -1", error_msg) + self.assertIn("got -5", error_msg) + + # Test prompt_logprobs error message + with self.assertRaises(ValueError) as cm: + params = SamplingParams(prompt_logprobs=-10) + params._verify_args() + + error_msg = str(cm.exception) + self.assertIn("prompt_logprobs must be greater than or equal to -1", error_msg) + self.assertIn("got -10", error_msg) + + def test_post_init_calls_verify_args(self): + """Test that __post_init__ calls _verify_args""" + # This should call _verify_args internally + params = SamplingParams(logprobs=5, prompt_logprobs=3) + + # The params should be successfully created without errors + self.assertEqual(params.logprobs, 5) + self.assertEqual(params.prompt_logprobs, 3) + + # Test that invalid values are caught during initialization + with self.assertRaises(ValueError): + SamplingParams(logprobs=-2) + + with self.assertRaises(ValueError): + SamplingParams(prompt_logprobs=-2) + + def test_logprobs_with_other_parameters(self): + """Test logprobs validation with other sampling parameters""" + # Test with temperature + params = SamplingParams(logprobs=5, temperature=0.8) + params._verify_args() # Should pass + + # Test with top_p + params = SamplingParams(logprobs=5, top_p=0.9) + 
params._verify_args() # Should pass
+
+        # Test with all parameters
+        params = SamplingParams(logprobs=5, prompt_logprobs=3, temperature=0.8, top_p=0.9, top_k=50, max_tokens=100)
+        params._verify_args() # Should pass
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/entrypoints/test_vllm_run_engine.py b/tests/entrypoints/test_vllm_run_engine.py
new file mode 100644
index 000000000..22783b197
--- /dev/null
+++ b/tests/entrypoints/test_vllm_run_engine.py
@@ -0,0 +1,100 @@
+from unittest.mock import MagicMock
+
+import numpy as np
+import pytest
+
+from fastdeploy.engine.sampling_params import SamplingParams
+from fastdeploy.entrypoints.llm import LLM
+from fastdeploy.worker.output import Logprob, LogprobsTensors
+
+
+class DummyModelConfig:
+    def __init__(self, max_logprobs=10, ori_vocab_size=50):
+        self.max_logprobs = max_logprobs
+        self.ori_vocab_size = ori_vocab_size
+
+
+@pytest.fixture
+def mock_llm():
+    llm = LLM.__new__(LLM)
+    llm.llm_engine = MagicMock()
+    llm.llm_engine.add_requests = MagicMock()
+    llm.llm_engine.cfg.model_config = DummyModelConfig(max_logprobs=10, ori_vocab_size=100)
+    # Mock the data_processor.process_logprob_response method to return proper strings
+    llm.llm_engine.data_processor = MagicMock()
+    llm.llm_engine.data_processor.process_logprob_response.side_effect = lambda ids, **kwargs: f"TOKEN_{ids[0]}"
+    return llm
+
+
+def test_prompt_logprobs_not_supported_with_stream(mock_llm):
+    sampling = SamplingParams(prompt_logprobs=5)
+    with pytest.raises(ValueError, match="prompt_logprobs is not supported with streaming"):
+        mock_llm._add_request(["hi"], sampling, stream=True)
+
+
+def test_num_logprobs_exceeds_max(mock_llm):
+    sampling = SamplingParams(logprobs=20)
+    with pytest.raises(ValueError, match="Number of logprobs requested"):
+        mock_llm._add_request(["hi"], sampling)
+
+
+def test_num_prompt_logprobs_exceeds_max(mock_llm):
+    sampling = SamplingParams(prompt_logprobs=20)
+    with pytest.raises(ValueError, match="Number of logprobs requested"):
+        mock_llm._add_request(["hi"], sampling)
+
+
+def test_logprobs_equal_to_minus_one_uses_ori_vocab_size(mock_llm):
+    sampling = SamplingParams(logprobs=-1)
+    mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
+    mock_llm.llm_engine.cfg.model_config.ori_vocab_size = 30
+    mock_llm._add_request(["hi"], sampling)
+    mock_llm.llm_engine.add_requests.assert_called_once()
+    # Get the first argument (tasks) which should be a dict
+    call_args = mock_llm.llm_engine.add_requests.call_args
+    tasks = call_args[0][0] # First positional argument
+    assert isinstance(tasks, dict)
+    assert "prompt" in tasks
+    assert "request_id" in tasks
+
+
+def test_prompt_logprobs_equal_to_minus_one(mock_llm):
+    sampling = SamplingParams(prompt_logprobs=-1)
+    mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
+    mock_llm.llm_engine.cfg.model_config.ori_vocab_size = 25
+    mock_llm._add_request(["hi"], sampling)
+    mock_llm.llm_engine.add_requests.assert_called_once()
+
+
+def test_build_prompt_logprobs_basic(mock_llm):
+    # Construct 2 tokens, each with 3 logprob values
+    token_ids = np.array([[1, 2, 3], [4, 5, 6]])
+    logprobs = np.array([[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]])
+    ranks = np.array([1, 2])
+    tensors = LogprobsTensors(token_ids, logprobs, ranks)
+
+    result = mock_llm._build_prompt_logprobs(tensors, num_prompt_logprobs=2)
+
+    # Check the format of the result
+    assert isinstance(result, list)
+    assert len(result) == 2
+    for pos_dict in result:
+        assert isinstance(pos_dict, dict)
+        for logprob_obj in pos_dict.values():
+            assert isinstance(logprob_obj,
Logprob) + assert logprob_obj.decoded_token.startswith("TOKEN_") + + +def test_build_prompt_logprobs_handles_minus_one(mock_llm): + token_ids = np.array([[7, 8]]) + logprobs = np.array([[-0.9, -1.0]]) + ranks = np.array([1]) + tensors = LogprobsTensors(token_ids, logprobs, ranks) + + result = mock_llm._build_prompt_logprobs(tensors, num_prompt_logprobs=-1) + + assert isinstance(result, list) + assert len(result) == 1 + pos_dict = result[0] + assert 7 in pos_dict + assert pos_dict[7].decoded_token == "TOKEN_7" diff --git a/tests/output/test_process_batch_output_use_zmq.py b/tests/output/test_process_batch_output_use_zmq.py new file mode 100644 index 000000000..85e3cf5cf --- /dev/null +++ b/tests/output/test_process_batch_output_use_zmq.py @@ -0,0 +1,135 @@ +import unittest +from unittest.mock import MagicMock, patch + +import numpy as np + +from fastdeploy.engine.request import CompletionOutput, RequestOutput +from fastdeploy.output.token_processor import TokenProcessor +from fastdeploy.worker.output import LogprobsLists + + +class TestTokenProcessorLogprobs(unittest.TestCase): + def setUp(self): + self.cfg = MagicMock() + self.cfg.model_config.enable_logprob = True + self.cfg.speculative_config.method = None + self.cfg.parallel_config.local_data_parallel_id = 0 + self.cached_generated_tokens = MagicMock() + self.engine_worker_queue = MagicMock() + self.split_connector = MagicMock() + + self.processor = TokenProcessor( + self.cfg, self.cached_generated_tokens, self.engine_worker_queue, self.split_connector + ) + + # Mock resource manager + self.processor.resource_manager = MagicMock() + self.processor.resource_manager.stop_flags = [False] + + # Create a proper task mock with time attributes + self.task_mock = MagicMock() + self.task_mock.request_id = "test_request" + self.task_mock.pooling_params = None + self.task_mock.messages = None + self.task_mock.disaggregate_info = None + self.task_mock.eos_token_ids = [2] + self.task_mock.inference_start_time = 100.0 # Set a float value for time calculation + self.task_mock.arrival_time = 90.0 + self.task_mock.preprocess_end_time = 95.0 + self.task_mock.preprocess_start_time = 90.0 + self.task_mock.schedule_start_time = 95.0 + + self.processor.resource_manager.tasks_list = [self.task_mock] + + # Mock logger + self.processor.llm_logger = MagicMock() + + # Mock metrics to avoid prometheus dependency issues + self.processor.main_process_metrics = MagicMock() + self.processor._recycle_resources = MagicMock() + + # Mock the _process_per_token method to avoid prometheus issues + self.processor._process_per_token = MagicMock() + self.processor._process_per_token.return_value = RequestOutput( + request_id="test_request", + outputs=CompletionOutput( + index=0, + send_idx=0, + token_ids=[], + draft_token_ids=[], + ), + finished=False, + metrics=MagicMock(), + ) + + def test_process_logprobs_success(self): + """Test successful logprobs parsing""" + stream_data = MagicMock() + logprobs = MagicMock() + logprobs.tolists.return_value = LogprobsLists( + logprobs=[[0.5]], logprob_token_ids=[[1]], sampled_token_ranks=[0] + ) + stream_data.logprobs = logprobs + stream_data.tokens = np.array([1]) + stream_data.batch_id = 0 + + result = self.processor._process_batch_output_use_zmq([stream_data]) + + self.assertEqual(len(result), 1) + self.processor.llm_logger.warning.assert_not_called() + + def test_process_logprobs_failure(self): + """Test failed logprobs parsing""" + stream_data = MagicMock() + stream_data.logprobs = MagicMock() + 
stream_data.logprobs.tolists.side_effect = Exception("Test error") + stream_data.tokens = np.array([1]) + stream_data.batch_id = 0 + + with patch.object(self.processor.llm_logger, "warning"): + result = self.processor._process_batch_output_use_zmq([stream_data]) + + self.assertEqual(len(result), 1) + self.assertIsNone(result[0].outputs.logprob) + + def test_process_prompt_logprobs_success(self): + """Test successful prompt_logprobs parsing""" + stream_data = MagicMock() + stream_data.logprobs = None + stream_data.prompt_logprobs = np.array([0.1, 0.2]) + stream_data.tokens = np.array([1]) + stream_data.batch_id = 0 + + result = self.processor._process_batch_output_use_zmq([stream_data]) + + self.assertEqual(len(result), 1) + self.processor.llm_logger.warning.assert_not_called() + + def test_process_prompt_logprobs_failure(self): + """Test failed prompt_logprobs parsing""" + stream_data = MagicMock() + stream_data.logprobs = None + stream_data.prompt_logprobs = MagicMock() + stream_data.prompt_logprobs.tolist.side_effect = AttributeError("'NoneType' object has no attribute 'tolist'") + stream_data.tokens = np.array([1]) + stream_data.batch_id = 0 + + with patch.object(self.processor.llm_logger, "warning"): + result = self.processor._process_batch_output_use_zmq([stream_data]) + + self.assertEqual(len(result), 1) + self.assertIsNone(getattr(result[0], "prompt_logprobs_tensors", None)) + + def test_process_batch_with_stop_flag(self): + """Test processing when stop flag is True""" + self.processor.resource_manager.stop_flags = [True] + stream_data = MagicMock() + stream_data.batch_id = 0 + + result = self.processor._process_batch_output_use_zmq([stream_data]) + + self.assertEqual(len(result), 0) + + +if __name__ == "__main__": + unittest.main()
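A minimal usage sketch of the API this patch adds (not part of the diff itself): the model path, prompt text, and the exact generate() call shape are illustrative assumptions; only SamplingParams(logprobs=..., prompt_logprobs=...), the "-1 means full vocabulary" convention, the rejection of prompt_logprobs under streaming, and RequestOutput.prompt_logprobs are taken from the changes above.

# Hedged sketch, not a verified end-to-end example.
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM

llm = LLM(model="./your-model-dir")  # hypothetical local model directory

params = SamplingParams(
    max_tokens=16,
    logprobs=5,          # top-5 logprobs per generated token
    prompt_logprobs=3,   # top-3 logprobs per prompt token; -1 would request the full vocab
)

# prompt_logprobs is rejected when stream=True, so use the non-streaming path.
outputs = llm.generate(["Hello, world"], params)

for output in outputs:
    # output.prompt_logprobs is a list with one dict per prompt position,
    # mapping token_id -> Logprob(logprob, rank, decoded_token).
    if output.prompt_logprobs:
        for pos, lp_dict in enumerate(output.prompt_logprobs):
            if lp_dict:
                token_id, lp = max(lp_dict.items(), key=lambda kv: kv[1].logprob)
                print(pos, token_id, lp.decoded_token, lp.logprob)
    # Per-generated-token dicts are built from outputs.top_logprobs when logprobs was set.
    if output.outputs.logprobs:
        print(output.outputs.logprobs[0])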