Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 08:37:06 +08:00)
[Feature] General support for logprobs (#2974)
* [Feature] support logprobs in chat/completions and completions endpoints
* Temporarily comment out text_offset due to incorrect logic
* Clean up temporary debug prints
* [Feature] support logprobs in offline mode via SamplingParams
* fix: serialize Logprob as dict before zmq send to fix msgpack error
* refactor: remove redundant methods to simplify codebase
* Fix missing fields in CompletionOutput.to_dict affecting msgpack serialization
* refactor: centralize param validation in engine_client to reduce duplication
* revert: rollback changes in offline_demo.py
* [bugfix] fix parameter validation for logprobs

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
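For illustration, a minimal offline usage sketch of the feature (assumptions: `LLM` is importable from `fastdeploy.entrypoints.llm`, `generate()` accepts a list of prompts plus `SamplingParams`; the model path and prompt are placeholders):

```python
from fastdeploy.engine.sampling_params import SamplingParams
from fastdeploy.entrypoints.llm import LLM  # assumed import path for the LLM entrypoint

# enable_logprob must be turned on at engine construction; it is forwarded to EngineArgs (see diff below).
llm = LLM(model="./your_model_dir", enable_logprob=True)

# SamplingParams.logprobs requests the top-k candidate logprobs per sampled token.
sampling_params = SamplingParams(logprobs=5)

outputs = llm.generate(["What is deep learning?"], sampling_params)
for output in outputs:
    # Each entry maps a candidate token id to a Logprob(logprob, rank, decoded_token).
    for logprob_dict in output.outputs.logprobs or []:
        print(logprob_dict)
```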
@@ -31,6 +31,7 @@ from fastdeploy.engine.sampling_params import SamplingParams
 # from fastdeploy.entrypoints.chat_utils import ChatCompletionMessageParam
 from fastdeploy.utils import llm_logger, retrive_model_from_server
+from fastdeploy.worker.output import Logprob, LogprobsLists
 
 root_logger = logging.getLogger()
 for handler in root_logger.handlers[:]:
@@ -68,12 +69,14 @@ class LLM:
         model: str,
         revision: Optional[str] = "master",
         tokenizer: Optional[str] = None,
+        enable_logprob: Optional[bool] = False,
         **kwargs,
     ):
         model = retrive_model_from_server(model, revision)
         engine_args = EngineArgs(
             model=model,
             tokenizer=tokenizer,
+            enable_logprob=enable_logprob,
             **kwargs,
         )
 
@@ -169,8 +172,10 @@ class LLM:
 
         req_ids = self._add_request(prompts=prompts, sampling_params=sampling_params)
 
+        topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs
+
         # get output
-        outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
+        outputs = self._run_engine(req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
         for i in range(len(outputs)):
             outputs[i].prompt = prompts[i]
         return outputs
@@ -223,8 +228,10 @@ class LLM:
             chat_template_kwargs=chat_template_kwargs,
         )
 
+        topk_logprobs = sampling_params[0].logprobs if sampling_params_len > 1 else sampling_params.logprobs
+
         # get output
-        outputs = self._run_engine(req_ids, use_tqdm=use_tqdm)
+        outputs = self._run_engine(req_ids, use_tqdm=use_tqdm, topk_logprobs=topk_logprobs)
         return outputs
 
     def _add_request(
@@ -278,7 +285,50 @@ class LLM:
         self.llm_engine.add_requests(tasks, current_sampling_params, enable_thinking=enable_thinking)
         return req_ids
 
-    def _run_engine(self, req_ids: list[str], use_tqdm: bool):
+    def _build_sample_logprobs(self, logprobs_lists: LogprobsLists, topk_logprobs: int) -> list[dict[int, Logprob]]:
+        """
+        Constructs a list of dictionaries mapping token IDs to Logprob objects,
+        based on sliced LogprobsLists data (excluding the sampled token at index 0).
+
+        Args:
+            logprobs_lists (LogprobsLists): Contains top-k token IDs, logprobs, and sampled ranks.
+            topk_logprobs (int): Maximum number of top logprobs to include (excluding the sampled token at index 0).
+
+        Returns:
+            list[dict[int, Logprob]]: One dict per request, mapping token ID to Logprob.
+        """
+        try:
+            llm_logger.info(f"filter logprobs, topk_logprobs: {topk_logprobs}")
+            if not logprobs_lists.logprob_token_ids:
+                llm_logger.warning("Empty logprob_token_ids in LogprobsLists")
+                return None
+
+            # exclude the sampled token at index 0
+            available_topk = len(logprobs_lists.logprob_token_ids[0]) - 1
+            effective_topk_logprobs = min(topk_logprobs, available_topk)
+
+            if effective_topk_logprobs <= 0:
+                llm_logger.warning(
+                    f"Invalid effective_topk_logprobs={effective_topk_logprobs}, "
+                    f"available_topk={available_topk}, topk_logprobs={topk_logprobs}; returning empty result."
+                )
+                return None
+
+            # slice columns 1 ~ (1 + effective_topk_logprobs)
+            sliced_logprobs_lists = logprobs_lists.slice_columns(1, 1 + effective_topk_logprobs)
+            result = []
+            for token_ids, logprobs in zip(sliced_logprobs_lists.logprob_token_ids, sliced_logprobs_lists.logprobs):
+                logprob_dict = {
+                    token_id: Logprob(logprob=logprob, rank=i + 1, decoded_token=None)
+                    for i, (token_id, logprob) in enumerate(zip(token_ids, logprobs))
+                }
+                result.append(logprob_dict)
+            return result
+
+        except Exception as e:
+            llm_logger.error(f"Error building sample logprobs from LogprobsLists: {e}")
+
+    def _run_engine(self, req_ids: list[str], use_tqdm: bool, topk_logprobs: Optional[int] = None):
         """
         Run the engine and return the list of results.
 
@@ -320,6 +370,13 @@ class LLM:
 
                 result = self.req_output.pop(req_id)
                 result = self.llm_engine.data_processor.process_response(result)
+
+                # filter logprobs
+                if result.outputs.top_logprobs and topk_logprobs:
+                    result.outputs.logprobs = self._build_sample_logprobs(
+                        result.outputs.top_logprobs, topk_logprobs
+                    )
+
                 output[pos] = result
                 finished.append(i)
 
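For reference, a standalone sketch of the filtering that `_build_sample_logprobs` performs on one row of top-k data; the token ids and logprob values are invented, and only the `Logprob` constructor signature and the index-0 convention are taken from the diff above.

```python
from fastdeploy.worker.output import Logprob  # same import added at the top of this diff

# One row of top-k data as the engine emits it: index 0 is the sampled token,
# indices 1..k are the ranked candidates (values here are made up).
logprob_token_ids = [1012, 1012, 734, 88]
logprobs = [-0.11, -0.11, -2.30, -4.70]

topk_logprobs = 2                              # what the user requested via SamplingParams.logprobs
available_topk = len(logprob_token_ids) - 1    # exclude the sampled token at index 0
effective = min(topk_logprobs, available_topk)

# Mirrors slice_columns(1, 1 + effective) in the PR: keep only the ranked candidates.
token_ids = logprob_token_ids[1 : 1 + effective]
values = logprobs[1 : 1 + effective]

logprob_dict = {
    token_id: Logprob(logprob=value, rank=i + 1, decoded_token=None)
    for i, (token_id, value) in enumerate(zip(token_ids, values))
}
# logprob_dict now holds {1012: rank-1 Logprob(-0.11), 734: rank-2 Logprob(-2.30)}
print(logprob_dict)
```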