[Feature] General support for logprobs (#2974)

* [Feature] support logprobs in chat/completions and completions endpoints

* Temporarily comment out text_offset due to incorrect logic

* Clean up temporary debug prints

* [Feature] support logprobs in offline mode via SamplingParams

* fix: serialize Logprob as dict before zmq send to fix msgpack error

* refactor: remove redundant methods to simplify codebase

* Fix missing fields in CompletionOutput.to_dict affecting msgpack serialization

* refactor: centralize param validation in engine_client to reduce duplication

* revert: rollback changes in offline_demo.py

* revert: rollback changes in offline_demo.py

* [bugfix] fix parameter validation for logprobs

* [bugfix] fix parameter validation for logprobs

* [bugfix] fix parameter validation for logprobs

* [bugfix] fix parameter validation for logprobs

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
SunLei
2025-07-31 20:25:56 +08:00
committed by GitHub
parent fe17410f9c
commit dade19d7a4
10 changed files with 330 additions and 44 deletions


@@ -20,6 +20,20 @@ from typing import NamedTuple, Optional
import paddle


class Logprob(NamedTuple):
    """
    A named tuple containing information about a token's log probability.
    """

    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


# [{token_id, logprob}] for tokens sampled from the top-k
SampleLogprobs = list[dict[int, Logprob]]


class LogprobsLists(NamedTuple):
    """ """
@@ -38,6 +52,17 @@ class LogprobsLists(NamedTuple):
            self.sampled_token_ranks[start:end],
        )

    def slice_columns(self, start: int, end: int):
        """
        Slice columns (per-row top-k logprobs and token IDs).
        Keeps the number of requests unchanged.
        """
        return LogprobsLists(
            [row[start:end] for row in self.logprob_token_ids],
            [row[start:end] for row in self.logprobs],
            self.sampled_token_ranks,  # unchanged
        )


class LogprobsTensors(NamedTuple):
    """ """
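
The behaviour of `slice_columns` and the shape of a `SampleLogprobs` entry can be illustrated with a small standalone sketch. It drops the `paddle` import and redeclares the two named tuples locally. The field types on `LogprobsLists` are assumptions, since the diff only shows the field names, and the sample token IDs and values are invented for illustration. The final `_asdict()` call is one possible way to turn a `Logprob` into a plain dict before serialization; the diff above does not show how the commit itself does it.

```python
from typing import NamedTuple, Optional


class Logprob(NamedTuple):
    logprob: float
    rank: Optional[int] = None
    decoded_token: Optional[str] = None


class LogprobsLists(NamedTuple):
    # Field types are assumed; the diff only shows the field names.
    logprob_token_ids: list[list[int]]  # one row per request
    logprobs: list[list[float]]         # same shape as logprob_token_ids
    sampled_token_ranks: list[int]      # one rank per request

    def slice_columns(self, start: int, end: int):
        # Trim each row to columns [start, end); the row count is unchanged.
        return LogprobsLists(
            [row[start:end] for row in self.logprob_token_ids],
            [row[start:end] for row in self.logprobs],
            self.sampled_token_ranks,
        )


# Two requests, three candidate tokens each (made-up values).
lists = LogprobsLists(
    logprob_token_ids=[[11, 12, 13], [21, 22, 23]],
    logprobs=[[-0.1, -1.5, -2.3], [-0.2, -0.9, -3.0]],
    sampled_token_ranks=[0, 0],
)

# Keep only the first two columns of every row.
top2 = lists.slice_columns(0, 2)

# Build one SampleLogprobs-style entry ({token_id: Logprob}) for the first request.
first_row = {
    token_id: Logprob(logprob=lp)
    for token_id, lp in zip(top2.logprob_token_ids[0], top2.logprobs[0])
}

# NamedTuple._asdict() yields a plain dict that generic serializers can handle.
as_plain_dict = first_row[11]._asdict()
```

Slicing by column only narrows how many candidates are kept per request, which is why `sampled_token_ranks` passes through untouched.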