Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
[Logprobs]Support prompt_logprobs and max_logprobs (#4897)
* add prompt logprobs
* trigger ci
* fix unitest
* Update fastdeploy/config.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update fastdeploy/entrypoints/llm.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update fastdeploy/engine/sampling_params.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update tests/engine/test_sampling_params.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update tests/engine/test_sampling_params.py
  Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* fix max_logprobs
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
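At a glance, this change lets callers request top-k log probabilities for the prompt tokens (prompt_logprobs) in addition to the sampled tokens (logprobs); a value of -1 means "full vocabulary", and model_config.max_logprobs caps both. A minimal usage sketch based on the diff below; the model path and the generate() entry point are illustrative assumptions, not part of this commit:

from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM

llm = LLM(model="your-model-path")  # hypothetical model path

# Top-5 logprobs per sampled token, top-3 logprobs per prompt token;
# -1 would request the full vocabulary for either field.
sampling = SamplingParams(max_tokens=32, logprobs=5, prompt_logprobs=3)

outputs = llm.generate(["Hello, how are you?"], sampling_params=sampling)  # assumed entry point
for out in outputs:
    print(out.outputs.logprobs)  # SampleLogprobs: dict[token_id, Logprob] per generated token
    print(out.prompt_logprobs)   # PromptLogprobs: dict[token_id, Logprob] per prompt token

Note that prompt_logprobs is rejected for streaming requests, and a request that resolves to more entries than max_logprobs raises a ValueError (see the _add_request hunk below).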
@@ -227,8 +227,8 @@ class ModelConfig:
        self.think_end_id = args.get("think_end_id", -1)
        self.im_patch_id = args.get("image_patch_id", -1)
        self.line_break_id = args.get("line_break_id", -1)
        if self.max_logprobs == -1 and hasattr(self, "vocab_size"):
            self.max_logprobs = self.vocab_size
        if self.max_logprobs < -1:
            raise ValueError(" The possible values for max_logprobs can't be less than -1 ")

        self._post_init()
@@ -29,7 +29,12 @@ from fastdeploy.engine.pooling_params import PoolingParams
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.openai.protocol import ToolCall
from fastdeploy.utils import data_processor_logger
from fastdeploy.worker.output import LogprobsLists, SampleLogprobs
from fastdeploy.worker.output import (
    LogprobsLists,
    LogprobsTensors,
    PromptLogprobs,
    SampleLogprobs,
)


class RequestStatus(Enum):
@@ -463,6 +468,8 @@ class RequestOutput:
        request_id: str,
        prompt: Optional[str] = None,
        prompt_token_ids: Optional[list[int]] = None,
        prompt_logprobs: Optional[PromptLogprobs] = None,
        prompt_logprobs_tensors: Optional[LogprobsTensors] = None,
        output_type: Optional[int] = 3,
        outputs: CompletionOutput = None,
        finished: bool = False,
@@ -476,6 +483,8 @@ class RequestOutput:
        self.request_id = request_id
        self.prompt = prompt
        self.prompt_token_ids = prompt_token_ids
        self.prompt_logprobs = prompt_logprobs
        self.prompt_logprobs_tensors = prompt_logprobs_tensors
        self.output_type = output_type
        self.outputs = outputs
        self.finished = finished
@@ -521,6 +530,7 @@ class RequestOutput:
            f"RequestOutput(request_id={self.request_id}, "
            f"prompt={self.prompt!r}, "
            f"prompt_token_ids={self.prompt_token_ids}, "
            f"prompt_logprobs={self.prompt_logprobs}, "
            f"output_type={self.output_type}, "
            f"outputs={self.outputs}, "
            f"finished={self.finished}, "
@@ -546,6 +556,7 @@ class RequestOutput:
            "request_id": self.request_id,
            "prompt": self.prompt,
            "prompt_token_ids": self.prompt_token_ids,
            "prompt_logprobs": self.prompt_logprobs,
            "output_type": self.output_type,
            "outputs": None if self.outputs is None else self.outputs.to_dict(),
            "metrics": None if self.metrics is None else self.metrics.to_dict(),
@@ -16,6 +16,7 @@

from __future__ import annotations

import os
import random
from dataclasses import dataclass, fields
from enum import Enum
@@ -206,10 +207,12 @@ class SamplingParams:
            raise ValueError(
                f"min_tokens must be less than or equal to " f"max_tokens={self.max_tokens}, got {self.min_tokens}."
            )
        if self.logprobs is not None and self.logprobs < 0:
            raise ValueError(f"logprobs must be non-negative, got {self.logprobs}.")
        if self.logprobs is not None and self.logprobs > 20:
        if self.logprobs is not None and self.logprobs < -1:
            raise ValueError(f"logprobs must be greater than -1, got {self.logprobs}.")
        if self.logprobs is not None and self.logprobs > 20 and os.getenv("FD_USE_GET_SAVE_OUTPUT_V1", "0") == "0":
            raise ValueError("Invalid value for 'top_logprobs': must be less than or equal to 20.")
        if self.prompt_logprobs is not None and self.prompt_logprobs < -1:
            raise ValueError(f"prompt_logprobs must be greater than or equal to -1, got {self.prompt_logprobs}.")

        if not 0 <= self.seed <= 922337203685477580:
            raise ValueError("seed must be in [0, 922337203685477580], got " f"{self.seed}.")
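Illustrative consequences of the tightened _verify_args checks above (a sketch; FD_USE_GET_SAVE_OUTPUT_V1 defaults to "0", so the top-20 cap applies unless the flag is set):

import os
from fastdeploy.engine.sampling_params import SamplingParams

SamplingParams(logprobs=5, prompt_logprobs=3)   # ok
SamplingParams(logprobs=-1)                     # ok: -1 means full-vocabulary logprobs
SamplingParams(prompt_logprobs=-1)              # ok

# SamplingParams(logprobs=-2)         # ValueError: logprobs must be greater than -1
# SamplingParams(prompt_logprobs=-2)  # ValueError: prompt_logprobs must be greater than or equal to -1
# SamplingParams(logprobs=21)         # ValueError while the flag is "0": top_logprobs must be <= 20

os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = "1"
SamplingParams(logprobs=100)                    # ok: the top-20 cap is skipped when the flag is "1"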
@@ -16,11 +16,13 @@

from __future__ import annotations

import itertools
import logging
import threading
import time
import traceback
import uuid
from collections.abc import Iterable
from typing import Any, Optional, Union

from pydantic import ValidationError
@@ -37,13 +39,20 @@ from fastdeploy.utils import (
    llm_logger,
    retrive_model_from_server,
)
from fastdeploy.worker.output import Logprob, LogprobsLists
from fastdeploy.worker.output import (
    Logprob,
    LogprobsLists,
    LogprobsTensors,
    PromptLogprobs,
)

root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
    if isinstance(handler, logging.StreamHandler):
        root_logger.removeHandler(handler)

NONES = itertools.repeat(None)


class LLM:
    """
@@ -189,12 +198,17 @@ class LLM:
        req_ids = self._add_request(prompts=prompts, sampling_params=sampling_params)

        topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs
        num_prompt_logprobs = (
            sampling_params[0].prompt_logprobs if sampling_params_len > 1 else sampling_params.prompt_logprobs
        )

        # get output
        if stream:
            return self._run_engine_stream(req_ids, prompts, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
        else:
            outputs = self._run_engine(req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
            outputs = self._run_engine(
                req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs, num_prompt_logprobs=num_prompt_logprobs
            )
            for i in range(len(outputs)):
                outputs[i].prompt = prompts[i]
            return outputs
@@ -321,6 +335,27 @@ class LLM:
                current_sampling_params = sampling_params[i]
            else:
                current_sampling_params = sampling_params
            if kwargs.get("stream") and current_sampling_params.prompt_logprobs is not None:
                raise ValueError("prompt_logprobs is not supported with streaming.")
            max_logprobs = self.llm_engine.cfg.model_config.max_logprobs
            if max_logprobs == -1:
                max_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
            if current_sampling_params.logprobs is not None:
                num_logprobs = current_sampling_params.logprobs
                if num_logprobs == -1:
                    num_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
                if num_logprobs > max_logprobs:
                    raise ValueError(
                        f"Number of logprobs requested ({num_logprobs}) exceeds maximum allowed value ({max_logprobs})."
                    )
            if current_sampling_params.prompt_logprobs is not None:
                num_prompt_logprobs = current_sampling_params.prompt_logprobs
                if num_prompt_logprobs == -1:
                    num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
                if num_prompt_logprobs > max_logprobs:
                    raise ValueError(
                        f"Number of logprobs requested ({num_prompt_logprobs}) exceeds maximum allowed value ({max_logprobs})."
                    )
            if current_sampling_params.guided_decoding is not None:
                guided_decoding_dict = current_sampling_params.guided_decoding.to_dict()
                tasks.update(guided_decoding_dict)
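The per-request cap enforced above, restated as a standalone sketch (the helper name is hypothetical; the attribute names mirror model_config in the diff):

# -1 on either side resolves to the full vocabulary (ori_vocab_size); a request
# that resolves to more entries than the configured cap is rejected.
def check_requested_logprobs(requested: int, max_logprobs: int, ori_vocab_size: int) -> None:
    if max_logprobs == -1:
        max_logprobs = ori_vocab_size
    if requested == -1:
        requested = ori_vocab_size
    if requested > max_logprobs:
        raise ValueError(
            f"Number of logprobs requested ({requested}) exceeds maximum allowed value ({max_logprobs})."
        )

check_requested_logprobs(5, max_logprobs=20, ori_vocab_size=50000)    # passes
check_requested_logprobs(-1, max_logprobs=-1, ori_vocab_size=50000)   # passes: both sides resolve to 50000
# check_requested_logprobs(30, max_logprobs=20, ori_vocab_size=50000) # would raise ValueError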
@@ -377,7 +412,93 @@ class LLM:
        except Exception as e:
            llm_logger.error(f"Error building sample logprobs from LogprobsLists: {e}, {str(traceback.format_exc())}")

    def _run_engine(self, req_ids: list[str], use_tqdm: bool, topk_logprobs: Optional[int] = None):
    def _build_prompt_logprobs(
        self,
        prompt_logprobs_tensors: LogprobsTensors,
        num_prompt_logprobs: int,
    ):
        """Update with prompt logprobs from worker.

        Args:
            prompt_logprobs_tensors: tuple containing the prompt logprobs
                tensors.
        """

        token_ids, logprobs, ranks = prompt_logprobs_tensors

        # Detokenize non-incrementally.
        # Output is flat: [num_tok, num_lps] -> [num_tok * num_lps]
        decoded_tokens = [self._decode_token(token_id) for token_id in token_ids.flatten().tolist()]

        # Recover shapes.
        num_prompt_tokens, num_logprobs = logprobs.shape

        # Pythonize the paddle tensors.
        prompt_token_ranks = ranks.tolist()
        prompt_logprobs = logprobs.tolist()
        token_ids = token_ids.tolist()
        result: Optional[PromptLogprobs] = []
        # Make Logprob for each position.
        for pos in range(num_prompt_tokens):
            # Handle flattening.
            offset = pos * num_logprobs
            offset_end = offset + num_logprobs
            decoded_tokens_for_pos = NONES if decoded_tokens is None else decoded_tokens[offset:offset_end]

            # Update with the Logprob dictionary for this pos.
            result.append(
                self._make_logprob_dict(
                    prompt_logprobs[pos],
                    token_ids[pos],
                    decoded_tokens_for_pos,
                    prompt_token_ranks[pos],
                    num_prompt_logprobs,
                )
            )
        return result

    @staticmethod
    def _make_logprob_dict(
        logprobs: list[float],
        logprob_token_ids: list[int],
        decoded_tokens: Iterable[str | None],
        rank: int,
        num_logprobs: int,
    ) -> dict[int, Logprob]:
        """Make a Logprob dictionary for a position.

        Args:
            logprobs: list of log probabilities
            logprob_token_ids: list of top token ids
            decoded_tokens: list of decoded top tokens
            rank: rank of the sampled token
            num_logprobs: number of logprobs requested
                by the user (in addition to sampled logprob)

        Returns:
            dict[token id, Logprob]
        """
        if num_logprobs == -1:
            num_logprobs = len(logprobs)
        # We do not need a special case for the sampled token
        # being in the topk, since inserting duplicated data
        # into a dictionary twice is the same as doing it once.
        topk_ranks = range(1, num_logprobs + 1)
        ranks = itertools.chain((rank,), topk_ranks)

        return {
            token_id: Logprob(
                logprob=logprob,
                rank=rank,
                decoded_token=token,
            )
            for token_id, logprob, rank, token in zip(logprob_token_ids, logprobs, ranks, decoded_tokens)
        }

    def _run_engine(
        self,
        req_ids: list[str],
        use_tqdm: bool,
        topk_logprobs: Optional[int] = None,
        num_prompt_logprobs: Optional[int] = None,
    ):
        """
        Run the engine and return the list of results.
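To make the rank handling in _make_logprob_dict concrete, a small worked example (values are invented; the call mirrors the staticmethod added above, which expects the sampled token's entry first, followed by the top-k entries):

from fastdeploy.entrypoints.llm import LLM

# Token id 7 was the sampled token at this position (rank 2 overall);
# the top-3 token ids by logprob were [3, 7, 9].
pos_dict = LLM._make_logprob_dict(
    logprobs=[-1.2, -0.5, -1.2, -2.0],
    logprob_token_ids=[7, 3, 7, 9],
    decoded_tokens=["C", "the", "C", "a"],
    rank=2,
    num_logprobs=3,
)
# pos_dict == {
#     7: Logprob(logprob=-1.2, rank=2, decoded_token="C"),
#     3: Logprob(logprob=-0.5, rank=1, decoded_token="the"),
#     9: Logprob(logprob=-2.0, rank=3, decoded_token="a"),
# }
# The duplicate entry for token 7 is harmless: writing identical data into the
# dict twice is the same as writing it once, which is why no special case is needed.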
@@ -422,9 +543,17 @@ class LLM:

                # filter logprobs
                if result.outputs.top_logprobs and topk_logprobs:
                    if topk_logprobs == -1:
                        topk_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
                    result.outputs.logprobs = self._build_sample_logprobs(
                        result.outputs.top_logprobs, topk_logprobs
                    )
                if result.prompt_logprobs_tensors and num_prompt_logprobs:
                    if num_prompt_logprobs == -1:
                        num_prompt_logprobs = self.llm_engine.cfg.model_config.ori_vocab_size
                    result.prompt_logprobs = self._build_prompt_logprobs(
                        result.prompt_logprobs_tensors, num_prompt_logprobs
                    )

                output[pos] = result
                finished.append(i)
@@ -290,6 +290,19 @@ class TokenProcessor:
                finished=False,
                metrics=metrics,
            )
            if self.use_logprobs:
                if getattr(stream_data, "logprobs", None) is not None:
                    try:
                        logprobs_list: LogprobsLists = stream_data.logprobs.tolists()
                        result.outputs.logprob = float(logprobs_list.logprobs[0][0])
                        result.outputs.top_logprobs = logprobs_list
                    except Exception as e:
                        llm_logger.warning(f"Failed to parse logprobs from StreamTransferData: {e}")
                if getattr(stream_data, "prompt_logprobs", None) is not None:
                    try:
                        result.prompt_logprobs_tensors = stream_data.prompt_logprobs
                    except Exception as e:
                        llm_logger.warning(f"Failed to parse prompt_logprobs from StreamTransferData: {e}")
            if self.tokens_counter[task_id] == 0:
                if task.messages is not None:
                    result.prompt = task.messages
@@ -30,6 +30,7 @@ class Logprob(NamedTuple):
    decoded_token: Optional[str] = None


PromptLogprobs = list[dict[int, Logprob] | None]
# [{token_id, logprob}] for tokens sampled from the top-k
SampleLogprobs = list[dict[int, Logprob]]
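For orientation, a PromptLogprobs value as defined above is a per-prompt-token list of logprob dictionaries, and the type allows None for positions that carry no logprobs; the values below are purely illustrative:

from fastdeploy.worker.output import Logprob

prompt_logprobs = [
    None,  # a position may carry no logprobs
    {101: Logprob(logprob=-0.3, rank=1, decoded_token="Hello"),
     102: Logprob(logprob=-2.1, rank=2, decoded_token="Hi")},
]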
tests/engine/test_sampling_params.py (new file, 222 lines)
@@ -0,0 +1,222 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import os
import unittest
from unittest.mock import patch

from fastdeploy.engine.sampling_params import SamplingParams


class TestSamplingParamsVerification(unittest.TestCase):
    """Test case for SamplingParams _verify_args method"""

    def test_logprobs_valid_values(self):
        """Test valid logprobs values"""
        # Test None value (should pass)
        params = SamplingParams(logprobs=None)
        params._verify_args()  # Should not raise

        # Test -1 value (should pass)
        params = SamplingParams(logprobs=-1)
        params._verify_args()  # Should not raise

        # Test 0 value (should pass)
        params = SamplingParams(logprobs=0)
        params._verify_args()  # Should not raise

        # Test 20 value (should pass when FD_USE_GET_SAVE_OUTPUT_V1 is "0")
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
            params = SamplingParams(logprobs=20)
            params._verify_args()  # Should not raise

    def test_logprobs_invalid_less_than_minus_one(self):
        """Test logprobs less than -1 should raise ValueError"""
        with self.assertRaises(ValueError) as cm:
            params = SamplingParams(logprobs=-2)
            params._verify_args()

        self.assertIn("logprobs must be greater than -1", str(cm.exception))
        self.assertIn("got -2", str(cm.exception))

    def test_logprobs_greater_than_20_with_v1_disabled(self):
        """Test logprobs greater than 20 when FD_USE_GET_SAVE_OUTPUT_V1 is disabled"""
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
            with self.assertRaises(ValueError) as cm:
                params = SamplingParams(logprobs=21)
                params._verify_args()

            self.assertEqual("Invalid value for 'top_logprobs': must be less than or equal to 20.", str(cm.exception))

    def test_logprobs_greater_than_20_with_v1_enabled(self):
        """Test logprobs greater than 20 when FD_USE_GET_SAVE_OUTPUT_V1 is enabled"""
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            # Should not raise when v1 is enabled
            params = SamplingParams(logprobs=21)
            params._verify_args()  # Should not raise

            # Test even larger values when v1 is enabled
            params = SamplingParams(logprobs=100)
            params._verify_args()  # Should not raise

    def test_prompt_logprobs_valid_values(self):
        """Test valid prompt_logprobs values"""
        # Test None value (should pass)
        params = SamplingParams(prompt_logprobs=None)
        params._verify_args()  # Should not raise

        # Test -1 value (should pass)
        params = SamplingParams(prompt_logprobs=-1)
        params._verify_args()  # Should not raise

        # Test 0 value (should pass)
        params = SamplingParams(prompt_logprobs=0)
        params._verify_args()  # Should not raise

        # Test positive values (should pass)
        params = SamplingParams(prompt_logprobs=10)
        params._verify_args()  # Should not raise

    def test_prompt_logprobs_invalid_less_than_minus_one(self):
        """Test prompt_logprobs less than -1 should raise ValueError"""
        with self.assertRaises(ValueError) as cm:
            params = SamplingParams(prompt_logprobs=-2)
            params._verify_args()

        self.assertIn("prompt_logprobs must be greater than or equal to -1", str(cm.exception))
        self.assertIn("got -2", str(cm.exception))

    def test_combined_logprobs_and_prompt_logprobs(self):
        """Test both logprobs and prompt_logprobs together"""
        # Test valid combination
        params = SamplingParams(logprobs=5, prompt_logprobs=3)
        params._verify_args()  # Should not raise

        # Test invalid logprobs with valid prompt_logprobs
        with self.assertRaises(ValueError):
            params = SamplingParams(logprobs=-2, prompt_logprobs=5)
            params._verify_args()

        # Test valid logprobs with invalid prompt_logprobs
        with self.assertRaises(ValueError):
            params = SamplingParams(logprobs=5, prompt_logprobs=-2)
            params._verify_args()

    def test_logprobs_boundary_values(self):
        """Test boundary values for logprobs"""
        # Test just below limit with v1 disabled
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
            params = SamplingParams(logprobs=20)
            params._verify_args()  # Should pass

        # Test just above limit with v1 disabled
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
            with self.assertRaises(ValueError):
                params = SamplingParams(logprobs=21)
                params._verify_args()

    def test_prompt_logprobs_boundary_values(self):
        """Test boundary values for prompt_logprobs"""
        # Test boundary value -1 (should pass)
        params = SamplingParams(prompt_logprobs=-1)
        params._verify_args()  # Should pass

        # Test boundary value just below -1 (should fail)
        with self.assertRaises(ValueError):
            params = SamplingParams(prompt_logprobs=-2)
            params._verify_args()

    def test_environment_variable_handling(self):
        """Test different environment variable values"""
        # Test FD_USE_GET_SAVE_OUTPUT_V1 = "0" (default behavior)
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "0"}):
            with self.assertRaises(ValueError):
                params = SamplingParams(logprobs=21)
                params._verify_args()

        # Test FD_USE_GET_SAVE_OUTPUT_V1 = "1" (relaxed behavior)
        with patch.dict(os.environ, {"FD_USE_GET_SAVE_OUTPUT_V1": "1"}):
            params = SamplingParams(logprobs=21)
            params._verify_args()  # Should pass

        # Test FD_USE_GET_SAVE_OUTPUT_V1 not set (default to "0")
        if "FD_USE_GET_SAVE_OUTPUT_V1" in os.environ:
            original_value = os.environ["FD_USE_GET_SAVE_OUTPUT_V1"]
            del os.environ["FD_USE_GET_SAVE_OUTPUT_V1"]
        else:
            original_value = None

        try:
            with self.assertRaises(ValueError):
                params = SamplingParams(logprobs=21)
                params._verify_args()
        finally:
            if original_value is not None:
                os.environ["FD_USE_GET_SAVE_OUTPUT_V1"] = original_value

    def test_error_message_formatting(self):
        """Test that error messages are properly formatted"""
        # Test logprobs error message
        with self.assertRaises(ValueError) as cm:
            params = SamplingParams(logprobs=-5)
            params._verify_args()

        error_msg = str(cm.exception)
        self.assertIn("logprobs must be greater than -1", error_msg)
        self.assertIn("got -5", error_msg)

        # Test prompt_logprobs error message
        with self.assertRaises(ValueError) as cm:
            params = SamplingParams(prompt_logprobs=-10)
            params._verify_args()

        error_msg = str(cm.exception)
        self.assertIn("prompt_logprobs must be greater than or equal to -1", error_msg)
        self.assertIn("got -10", error_msg)

    def test_post_init_calls_verify_args(self):
        """Test that __post_init__ calls _verify_args"""
        # This should call _verify_args internally
        params = SamplingParams(logprobs=5, prompt_logprobs=3)

        # The params should be successfully created without errors
        self.assertEqual(params.logprobs, 5)
        self.assertEqual(params.prompt_logprobs, 3)

        # Test that invalid values are caught during initialization
        with self.assertRaises(ValueError):
            SamplingParams(logprobs=-2)

        with self.assertRaises(ValueError):
            SamplingParams(prompt_logprobs=-2)

    def test_logprobs_with_other_parameters(self):
        """Test logprobs validation with other sampling parameters"""
        # Test with temperature
        params = SamplingParams(logprobs=5, temperature=0.8)
        params._verify_args()  # Should pass

        # Test with top_p
        params = SamplingParams(logprobs=5, top_p=0.9)
        params._verify_args()  # Should pass

        # Test with all parameters
        params = SamplingParams(logprobs=5, prompt_logprobs=3, temperature=0.8, top_p=0.9, top_k=50, max_tokens=100)
        params._verify_args()  # Should pass


if __name__ == "__main__":
    unittest.main()
tests/entrypoints/test_vllm_run_engine.py (new file, 100 lines)
@@ -0,0 +1,100 @@
from unittest.mock import MagicMock

import numpy as np
import pytest

from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM
from fastdeploy.worker.output import Logprob, LogprobsTensors


class DummyModelConfig:
    def __init__(self, max_logprobs=10, ori_vocab_size=50):
        self.max_logprobs = max_logprobs
        self.ori_vocab_size = ori_vocab_size


@pytest.fixture
def mock_llm():
    llm = LLM.__new__(LLM)
    llm.llm_engine = MagicMock()
    llm.llm_engine.add_requests = MagicMock()
    llm.llm_engine.cfg.model_config = DummyModelConfig(max_logprobs=10, ori_vocab_size=100)
    # Mock the data_processor.process_logprob_response method to return proper strings
    llm.llm_engine.data_processor = MagicMock()
    llm.llm_engine.data_processor.process_logprob_response.side_effect = lambda ids, **kwargs: f"TOKEN_{ids[0]}"
    return llm


def test_prompt_logprobs_not_supported_with_stream(mock_llm):
    sampling = SamplingParams(prompt_logprobs=5)
    with pytest.raises(ValueError, match="prompt_logprobs is not supported with streaming"):
        mock_llm._add_request(["hi"], sampling, stream=True)


def test_num_logprobs_exceeds_max(mock_llm):
    sampling = SamplingParams(logprobs=20)
    with pytest.raises(ValueError, match="Number of logprobs requested"):
        mock_llm._add_request(["hi"], sampling)


def test_num_prompt_logprobs_exceeds_max(mock_llm):
    sampling = SamplingParams(prompt_logprobs=20)
    with pytest.raises(ValueError, match="Number of logprobs requested"):
        mock_llm._add_request(["hi"], sampling)


def test_logprobs_equal_to_minus_one_uses_ori_vocab_size(mock_llm):
    sampling = SamplingParams(logprobs=-1)
    mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
    mock_llm.llm_engine.cfg.model_config.ori_vocab_size = 30
    mock_llm._add_request(["hi"], sampling)
    mock_llm.llm_engine.add_requests.assert_called_once()
    # Get the first argument (tasks) which should be a dict
    call_args = mock_llm.llm_engine.add_requests.call_args
    tasks = call_args[0][0]  # First positional argument
    assert isinstance(tasks, dict)
    assert "prompt" in tasks
    assert "request_id" in tasks


def test_prompt_logprobs_equal_to_minus_one(mock_llm):
    sampling = SamplingParams(prompt_logprobs=-1)
    mock_llm.llm_engine.cfg.model_config.max_logprobs = -1
    mock_llm.llm_engine.cfg.model_config.ori_vocab_size = 25
    mock_llm._add_request(["hi"], sampling)
    mock_llm.llm_engine.add_requests.assert_called_once()


def test_build_prompt_logprobs_basic(mock_llm):
    # Build 2 tokens, each with 3 logprob values
    token_ids = np.array([[1, 2, 3], [4, 5, 6]])
    logprobs = np.array([[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]])
    ranks = np.array([1, 2])
    tensors = LogprobsTensors(token_ids, logprobs, ranks)

    result = mock_llm._build_prompt_logprobs(tensors, num_prompt_logprobs=2)

    # Check the result format
    assert isinstance(result, list)
    assert len(result) == 2
    for pos_dict in result:
        assert isinstance(pos_dict, dict)
        for logprob_obj in pos_dict.values():
            assert isinstance(logprob_obj, Logprob)
            assert logprob_obj.decoded_token.startswith("TOKEN_")


def test_build_prompt_logprobs_handles_minus_one(mock_llm):
    token_ids = np.array([[7, 8]])
    logprobs = np.array([[-0.9, -1.0]])
    ranks = np.array([1])
    tensors = LogprobsTensors(token_ids, logprobs, ranks)

    result = mock_llm._build_prompt_logprobs(tensors, num_prompt_logprobs=-1)

    assert isinstance(result, list)
    assert len(result) == 1
    pos_dict = result[0]
    assert 7 in pos_dict
    assert pos_dict[7].decoded_token == "TOKEN_7"
tests/output/test_process_batch_output_use_zmq.py (new file, 135 lines)
@@ -0,0 +1,135 @@
import unittest
from unittest.mock import MagicMock, patch

import numpy as np

from fastdeploy.engine.request import CompletionOutput, RequestOutput
from fastdeploy.output.token_processor import TokenProcessor
from fastdeploy.worker.output import LogprobsLists


class TestTokenProcessorLogprobs(unittest.TestCase):
    def setUp(self):
        self.cfg = MagicMock()
        self.cfg.model_config.enable_logprob = True
        self.cfg.speculative_config.method = None
        self.cfg.parallel_config.local_data_parallel_id = 0
        self.cached_generated_tokens = MagicMock()
        self.engine_worker_queue = MagicMock()
        self.split_connector = MagicMock()

        self.processor = TokenProcessor(
            self.cfg, self.cached_generated_tokens, self.engine_worker_queue, self.split_connector
        )

        # Mock resource manager
        self.processor.resource_manager = MagicMock()
        self.processor.resource_manager.stop_flags = [False]

        # Create a proper task mock with time attributes
        self.task_mock = MagicMock()
        self.task_mock.request_id = "test_request"
        self.task_mock.pooling_params = None
        self.task_mock.messages = None
        self.task_mock.disaggregate_info = None
        self.task_mock.eos_token_ids = [2]
        self.task_mock.inference_start_time = 100.0  # Set a float value for time calculation
        self.task_mock.arrival_time = 90.0
        self.task_mock.preprocess_end_time = 95.0
        self.task_mock.preprocess_start_time = 90.0
        self.task_mock.schedule_start_time = 95.0

        self.processor.resource_manager.tasks_list = [self.task_mock]

        # Mock logger
        self.processor.llm_logger = MagicMock()

        # Mock metrics to avoid prometheus dependency issues
        self.processor.main_process_metrics = MagicMock()
        self.processor._recycle_resources = MagicMock()

        # Mock the _process_per_token method to avoid prometheus issues
        self.processor._process_per_token = MagicMock()
        self.processor._process_per_token.return_value = RequestOutput(
            request_id="test_request",
            outputs=CompletionOutput(
                index=0,
                send_idx=0,
                token_ids=[],
                draft_token_ids=[],
            ),
            finished=False,
            metrics=MagicMock(),
        )

    def test_process_logprobs_success(self):
        """Test successful logprobs parsing"""
        stream_data = MagicMock()
        logprobs = MagicMock()
        logprobs.tolists.return_value = LogprobsLists(
            logprobs=[[0.5]], logprob_token_ids=[[1]], sampled_token_ranks=[0]
        )
        stream_data.logprobs = logprobs
        stream_data.tokens = np.array([1])
        stream_data.batch_id = 0

        result = self.processor._process_batch_output_use_zmq([stream_data])

        self.assertEqual(len(result), 1)
        self.processor.llm_logger.warning.assert_not_called()

    def test_process_logprobs_failure(self):
        """Test failed logprobs parsing"""
        stream_data = MagicMock()
        stream_data.logprobs = MagicMock()
        stream_data.logprobs.tolists.side_effect = Exception("Test error")
        stream_data.tokens = np.array([1])
        stream_data.batch_id = 0

        with patch.object(self.processor.llm_logger, "warning"):
            result = self.processor._process_batch_output_use_zmq([stream_data])

        self.assertEqual(len(result), 1)
        self.assertIsNone(result[0].outputs.logprob)

    def test_process_prompt_logprobs_success(self):
        """Test successful prompt_logprobs parsing"""
        stream_data = MagicMock()
        stream_data.logprobs = None
        stream_data.prompt_logprobs = np.array([0.1, 0.2])
        stream_data.tokens = np.array([1])
        stream_data.batch_id = 0

        result = self.processor._process_batch_output_use_zmq([stream_data])

        self.assertEqual(len(result), 1)
        self.processor.llm_logger.warning.assert_not_called()

    def test_process_prompt_logprobs_failure(self):
        """Test failed prompt_logprobs parsing"""
        stream_data = MagicMock()
        stream_data.logprobs = None
        stream_data.prompt_logprobs = MagicMock()
        stream_data.prompt_logprobs.tolist.side_effect = AttributeError("'NoneType' object has no attribute 'tolist'")
        stream_data.tokens = np.array([1])
        stream_data.batch_id = 0

        with patch.object(self.processor.llm_logger, "warning"):
            result = self.processor._process_batch_output_use_zmq([stream_data])

        self.assertEqual(len(result), 1)
        self.assertIsNone(getattr(result[0], "prompt_logprobs_tensors", None))

    def test_process_batch_with_stop_flag(self):
        """Test processing when stop flag is True"""
        self.processor.resource_manager.stop_flags = [True]
        stream_data = MagicMock()
        stream_data.batch_id = 0

        result = self.processor._process_batch_output_use_zmq([stream_data])

        self.assertEqual(len(result), 0)


if __name__ == "__main__":
    unittest.main()