[Feature] General support for logprobs (#2974)

* [Feature] support logprobs in chat/completions and completions endpoints (see the endpoint sketch after this list)

* Temporarily comment out text_offset due to incorrect logic

* Clean up temporary debug prints

* [Feature] support logprobs in offline mode via SamplingParams (see the offline sketch after this list)

* fix: serialize Logprob as dict before zmq send to fix msgpack error (see the serialization sketch after this list)

* refactor: remove redundant methods to simplify codebase

* Fix missing fields in CompletionOutput.to_dict affecting msgpack serialization

* refactor: centralize param validation in engine_client to reduce duplication (see the validation sketch after this list)

* revert: rollback changes in offline_demo.py

* revert: rollback changes in offline_demo.py

* [bugfix] fix parameter validation for logprobs

* [bugfix] fix parameter validation for logprobs

* [bugfix] fix parameter validation for logprobs

* [bugfix] fix parameter validation for logprobs
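
The sketch below illustrates the first bullet: requesting logprobs through the OpenAI-compatible chat/completions endpoint. It is a minimal sketch, assuming a FastDeploy server is already running; the base_url, api_key, and model name are placeholders for your deployment, and the response fields follow the standard OpenAI Python client.

```python
# Minimal sketch, assuming a FastDeploy OpenAI-compatible server is running;
# base_url, api_key, and model are placeholders for your deployment.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8188/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello!"}],
    logprobs=True,     # return the logprob of each sampled token
    top_logprobs=5,    # plus the 5 most likely alternatives per position
)

# One entry per generated token, each with its logprob and top alternatives.
for token_info in resp.choices[0].logprobs.content:
    print(token_info.token, token_info.logprob)
```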
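
For the offline bullet, the diff below adds an enable_logprob switch to LLM and threads SamplingParams.logprobs through _run_engine. Here is a minimal usage sketch under those assumptions; the import path and model directory are placeholders in the style of the project's offline demo:

```python
# Minimal sketch of the offline logprobs path; the import path and model
# directory are assumptions, adjust them to your installation.
from fastdeploy import LLM, SamplingParams

llm = LLM(model="./your-model-dir", enable_logprob=True)
sampling_params = SamplingParams(temperature=0.8, logprobs=5)

outputs = llm.generate(["What is machine learning?"], sampling_params)

# Per the new _build_sample_logprobs helper, each generated position carries a
# dict mapping candidate token id -> Logprob(logprob, rank, decoded_token).
for step in outputs[0].outputs.logprobs:
    for token_id, lp in step.items():
        print(token_id, lp.logprob, lp.rank)
```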
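
The msgpack fix exists because msgpack only packs plain Python types, so a Logprob object has to be converted to a dict before it is sent over zmq. The sketch below reproduces the failure mode with a stand-in dataclass that mirrors the Logprob fields used in this diff (logprob, rank, decoded_token); it is illustrative, not FastDeploy's actual transport code.

```python
# Illustrative only: a stand-in Logprob with the same fields as the one in
# fastdeploy.worker.output, showing why it must be serialized as a dict.
from dataclasses import asdict, dataclass
from typing import Optional

import msgpack


@dataclass
class Logprob:
    logprob: float
    rank: int
    decoded_token: Optional[str] = None


lp = Logprob(logprob=-0.25, rank=1, decoded_token="hello")

# msgpack.packb(lp) raises TypeError because msgpack cannot pack custom objects,
# so pack a plain dict instead and rebuild the object on the receiving side.
payload = msgpack.packb(asdict(lp))
restored = Logprob(**msgpack.unpackb(payload))
print(restored)
```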
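
The validation bullets (centralizing checks in engine_client plus the follow-up bugfixes) amount to rejecting bad logprobs values before they reach the engine. A hypothetical sketch of such a check follows; the helper name, limit, and error type are illustrative and not FastDeploy's actual API:

```python
# Hypothetical illustration only: name, limit, and error type are not taken
# from FastDeploy's engine_client, they just show the shape of the check.
def validate_logprobs(logprobs, max_top_logprobs: int = 20) -> None:
    """Reject out-of-range logprobs requests before they reach the engine."""
    if logprobs is None:
        return
    if not isinstance(logprobs, int) or logprobs < 0:
        raise ValueError(f"logprobs must be a non-negative integer, got {logprobs!r}")
    if logprobs > max_top_logprobs:
        raise ValueError(f"logprobs must be <= {max_top_logprobs}, got {logprobs}")
```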

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Author:    SunLei
Date:      2025-07-31 20:25:56 +08:00
Committer: GitHub
Parent:    fe17410f9c
Commit:    dade19d7a4

10 changed files with 330 additions and 44 deletions

@@ -31,6 +31,7 @@ from fastdeploy.engine.sampling_params import SamplingParams
 # from fastdeploy.entrypoints.chat_utils import ChatCompletionMessageParam
 from fastdeploy.utils import llm_logger, retrive_model_from_server
+from fastdeploy.worker.output import Logprob, LogprobsLists
 
 root_logger = logging.getLogger()
 for handler in root_logger.handlers[:]:
@@ -68,12 +69,14 @@ class LLM:
         model: str,
         revision: Optional[str] = "master",
         tokenizer: Optional[str] = None,
+        enable_logprob: Optional[bool] = False,
         **kwargs,
     ):
         model = retrive_model_from_server(model, revision)
 
         engine_args = EngineArgs(
             model=model,
             tokenizer=tokenizer,
+            enable_logprob=enable_logprob,
             **kwargs,
         )
@@ -169,8 +172,10 @@ class LLM:
         req_ids = self._add_request(prompts=prompts, sampling_params=sampling_params)
 
+        topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs
+
         # get output
-        outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
+        outputs = self._run_engine(req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
         for i in range(len(outputs)):
             outputs[i].prompt = prompts[i]
         return outputs
@@ -223,8 +228,10 @@
             chat_template_kwargs=chat_template_kwargs,
         )
 
+        topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs
+
         # get output
-        outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
+        outputs = self._run_engine(req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
         return outputs
 
     def _add_request(
@@ -278,7 +285,50 @@
         self.llm_engine.add_requests(tasks, current_sampling_params, enable_thinking=enable_thinking)
         return req_ids
 
-    def _run_engine(self, req_ids: list[str], use_tqdm: bool):
+    def _build_sample_logprobs(self, logprobs_lists: LogprobsLists, topk_logprobs: int) -> list[dict[int, Logprob]]:
+        """
+        Constructs a list of dictionaries mapping token IDs to Logprob objects,
+        based on sliced LogprobsLists data (excluding the sampled token at index 0).
+
+        Args:
+            logprobs_lists (LogprobsLists): Contains top-k token IDs, logprobs, and sampled ranks.
+            topk_logprobs (int): Maximum number of top logprobs to include (excluding sampled token at index 0).
+
+        Returns:
+            list[dict[int, Logprob]]: One dict per request, mapping token ID to Logprob.
+        """
+        try:
+            llm_logger.info(f"filter logprobs, topk_logprobs: {topk_logprobs}")
+            if not logprobs_lists.logprob_token_ids:
+                llm_logger.warning("Empty logprob_token_ids in LogprobsLists")
+                return None
+
+            # exclude sampled token at index 0
+            available_topk = len(logprobs_lists.logprob_token_ids[0]) - 1
+            effective_topk_logprobs = min(topk_logprobs, available_topk)
+
+            if effective_topk_logprobs <= 0:
+                llm_logger.warning(
+                    f"Invalid effective_topk_logprobs={effective_topk_logprobs}, "
+                    f"available_topk={available_topk}, topk_logprobs={topk_logprobs}; returning empty result."
+                )
+                return None
+
+            # slice columns 1 ~ (1 + effective_topk_logprobs)
+            sliced_logprobs_lists = logprobs_lists.slice_columns(1, 1 + effective_topk_logprobs)
+
+            result = []
+            for token_ids, logprobs in zip(sliced_logprobs_lists.logprob_token_ids, sliced_logprobs_lists.logprobs):
+                logprob_dict = {
+                    token_id: Logprob(logprob=logprob, rank=i + 1, decoded_token=None)
+                    for i, (token_id, logprob) in enumerate(zip(token_ids, logprobs))
+                }
+                result.append(logprob_dict)
+            return result
+        except Exception as e:
+            llm_logger.error(f"Error building sample logprobs from LogprobsLists: {e}")
+
+    def _run_engine(self, req_ids: list[str], use_tqdm: bool, topk_logprobs: Optional[int] = None):
         """
         Run the engine and return the list of results.
@@ -320,6 +370,13 @@ class LLM:
                     result = self.req_output.pop(req_id)
                     result = self.llm_engine.data_processor.process_response(result)
 
+                    # filter logprobs
+                    if result.outputs.top_logprobs and topk_logprobs:
+                        result.outputs.logprobs = self._build_sample_logprobs(
+                            result.outputs.top_logprobs, topk_logprobs
+                        )
+
                     output[pos] = result
                     finished.append(i)
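
To make the column slicing in _build_sample_logprobs concrete: column 0 of each row holds the sampled token, so the requested top-k alternatives come from columns 1 .. 1 + k. Below is a standalone sketch of the same logic on plain lists, with plain dicts standing in for Logprob objects and LogprobsLists:

```python
# Standalone illustration of the slicing done by _build_sample_logprobs:
# column 0 is the sampled token, the top-k alternatives start at column 1.
logprob_token_ids = [[1012, 7, 42, 9]]    # one step: sampled token 1012, then 3 candidates
logprobs = [[-0.10, -0.12, -1.30, -2.70]]
topk_logprobs = 2

available_topk = len(logprob_token_ids[0]) - 1       # 3 candidates besides the sampled token
effective_topk = min(topk_logprobs, available_topk)  # -> 2

result = []
for token_ids, lps in zip(logprob_token_ids, logprobs):
    sliced = zip(token_ids[1:1 + effective_topk], lps[1:1 + effective_topk])
    result.append({tok: {"logprob": lp, "rank": i + 1} for i, (tok, lp) in enumerate(sliced)})

print(result)  # [{7: {'logprob': -0.12, 'rank': 1}, 42: {'logprob': -1.3, 'rank': 2}}]
```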