[Feature] General support for logprobs (#2974)

* [Feature] support logprobs in chat/completions and completions endpoints (see the request sketch below)

* Temporarily comment out text_offset due to incorrect logic

* Clean up temporary debug prints

* [Feature] support logprobs in offline mode via SamplingParams (see the offline sketch below)

* fix: serialize Logprob as dict before zmq send to fix msgpack error (see the msgpack sketch below)

* refactor: remove redundant methods to simplify codebase

* Fix missing fields in CompletionOutput.to_dict affecting msgpack serialization

* refactor: centralize param validation in engine_client to reduce duplication

* revert: rollback changes in offline_demo.py

* [bugfix] fix parameter validation for logprobs

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
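
As a quick illustration of the endpoint support above, a minimal client sketch against the OpenAI-compatible completions API. The host, port, and model name are placeholders rather than values from this commit; the response fields follow the CompletionLogprobs shape assembled in the diff below (tokens, token_logprobs, top_logprobs).

```python
# Hypothetical client-side sketch: host, port, and model name are placeholders,
# not values taken from this commit.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "default",
        "prompt": "Hello, my name is",
        "max_tokens": 8,
        "logprobs": 2,  # request the top-2 alternatives per generated token
    },
)
choice = resp.json()["choices"][0]
# OpenAI-style logprobs: parallel lists, plus one dict of alternatives per token
print(choice["logprobs"]["tokens"])
print(choice["logprobs"]["token_logprobs"])
print(choice["logprobs"]["top_logprobs"])
```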
Author: SunLei
Date: 2025-07-31 20:25:56 +08:00 (committed by GitHub)
Parent: fe17410f9c
Commit: dade19d7a4
10 changed files with 330 additions and 44 deletions
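
The offline path goes through SamplingParams instead of the HTTP layer. A hedged sketch, assuming FastDeploy's vLLM-style offline entry points (LLM, SamplingParams, llm.generate) referenced by the reverted offline_demo.py changes; the model path and the attribute layout of the returned outputs are assumptions:

```python
# Hedged offline sketch: the model path is a placeholder, and the layout of the
# returned outputs is assumed to mirror CompletionOutput in this commit.
from fastdeploy import LLM, SamplingParams

sampling_params = SamplingParams(max_tokens=8, logprobs=2)  # top-2 alternatives per token
llm = LLM(model="/path/to/model")
outputs = llm.generate(["Hello, my name is"], sampling_params)

for output in outputs:
    # Each generated token carries its logprob plus the requested alternatives
    print(output.outputs.token_ids, output.outputs.logprobs)
```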

@@ -17,7 +17,7 @@
import asyncio
import time
import uuid
from typing import List
from typing import List, Optional
import aiozmq
import msgpack
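
The "serialize Logprob as dict before zmq send" fix comes down to msgpack only packing plain types. A standalone sketch; the Logprob dataclass here is a stand-in, not the one from fastdeploy.worker.output:

```python
# Standalone illustration; this Logprob dataclass is a stand-in, not fastdeploy's.
import dataclasses
import msgpack

@dataclasses.dataclass
class Logprob:
    logprob: float
    rank: int
    decoded_token: str

lp = Logprob(logprob=-0.12, rank=1, decoded_token="Hello")

# msgpack.packb(lp) would raise TypeError: msgpack cannot serialize arbitrary objects.
packed = msgpack.packb(dataclasses.asdict(lp))  # dicts of plain types pack fine
restored = msgpack.unpackb(packed)
print(restored)  # {'logprob': -0.12, 'rank': 1, 'decoded_token': 'Hello'}
```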
@@ -26,6 +26,7 @@ from aiozmq import zmq
from fastdeploy.engine.request import RequestOutput
from fastdeploy.entrypoints.openai.protocol import (
CompletionLogprobs,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
@@ -35,6 +36,7 @@ from fastdeploy.entrypoints.openai.protocol import (
UsageInfo,
)
from fastdeploy.utils import api_server_logger, get_host_ip
from fastdeploy.worker.output import LogprobsLists
class OpenAIServingCompletion:
@@ -160,6 +162,8 @@ class OpenAIServingCompletion:
valid_results = [dict()] * num_choices
output_tokens = [0] * num_choices
aggregated_top_logprobs = [[[], [], []] for _ in range(num_choices)]
aggregated_token_ids = [[] for _ in range(num_choices)]
completion_batched_token_ids = [[] for _ in range(num_choices)]
current_waiting_time = 0
while num_choices > 0:
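
A note on the two aggregation buffers initialized in this hunk: multiplying a list that holds mutable elements copies references rather than values, which is why the buffers are built with comprehensions (matching completion_batched_token_ids). A quick demonstration:

```python
# Why the buffers use comprehensions: list multiplication aliases mutable elements.
shared = [[[], [], []]] * 2
shared[0][0].append("token")
print(shared[1][0])   # ['token']  -- both "choices" see the same inner lists

independent = [[[], [], []] for _ in range(2)]
independent[0][0].append("token")
print(independent[1][0])  # []      -- each choice gets its own buffers
```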
@@ -182,6 +186,15 @@ class OpenAIServingCompletion:
if data.get("error_code", 200) != 200:
raise ValueError("{}".format(data["error_msg"]))
output = data["outputs"]
output_top_logprobs = output["top_logprobs"]
if output_top_logprobs is not None:
aggregated_top_logprobs[rid][0].extend(output_top_logprobs[0])
aggregated_top_logprobs[rid][1].extend(output_top_logprobs[1])
aggregated_top_logprobs[rid][2].extend(output_top_logprobs[2])
aggregated_token_ids[rid].extend(data["outputs"]["token_ids"])
self.engine_client.data_processor.process_response_dict(
data, stream=False, include_stop_str_in_output=request.include_stop_str_in_output
)
@@ -189,6 +202,8 @@ class OpenAIServingCompletion:
completion_batched_token_ids[rid].extend(data["outputs"]["token_ids"])
if data.get("finished", False):
data["output_token_ids"] = output_tokens[rid]
data["outputs"]["top_logprobs"] = aggregated_top_logprobs[rid]
data["outputs"]["token_ids"] = aggregated_token_ids[rid]
valid_results[rid] = data
num_choices -= 1
break
@@ -292,6 +307,10 @@ class OpenAIServingCompletion:
arrival_time = res["metrics"]["arrival_time"] - inference_start_time[idx]
output = res["outputs"]
output_top_logprobs = output["top_logprobs"]
logprobs_res: Optional[CompletionLogprobs] = None
if request.logprobs and output_top_logprobs is not None:
logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)
choices.append(
CompletionResponseStreamChoice(
@@ -302,6 +321,7 @@ class OpenAIServingCompletion:
tool_calls=output.get("tool_call_content"),
reasoning_content=output.get("reasoning_content"),
arrival_time=arrival_time,
logprobs=logprobs_res,
)
)
if res["finished"]:
@@ -367,6 +387,7 @@ class OpenAIServingCompletion:
choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0
num_generated_tokens = 0
aggregated_logprobs: Optional[CompletionLogprobs] = None
for idx in range(len(final_res_batch)):
final_res = final_res_batch[idx]
@@ -376,6 +397,18 @@ class OpenAIServingCompletion:
completion_token_ids = completion_batched_token_ids[idx]
output = final_res["outputs"]
output_top_logprobs = output["top_logprobs"]
if output_top_logprobs is not None:
logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)
if aggregated_logprobs is None:
aggregated_logprobs = logprobs_res
else:
aggregated_logprobs.tokens.extend(logprobs_res.tokens)
aggregated_logprobs.token_logprobs.extend(logprobs_res.token_logprobs)
aggregated_logprobs.top_logprobs.extend(logprobs_res.top_logprobs)
aggregated_logprobs.text_offset.extend(logprobs_res.text_offset)
if request.echo:
assert prompt_text is not None
if request.max_tokens == 0:
@@ -396,7 +429,7 @@ class OpenAIServingCompletion:
completion_token_ids=completion_token_ids if request.return_token_ids else None,
reasoning_content=output.get("reasoning_content"),
tool_calls=output.get("tool_call_content"),
logprobs=None,
logprobs=aggregated_logprobs,
finish_reason=None,
)
choices.append(choice_data)
@@ -419,3 +452,99 @@ class OpenAIServingCompletion:
choices=choices,
usage=usage,
)

    def _create_completion_logprobs(
        self,
        output_top_logprobs,
        request_logprobs: Optional[int] = None,
        prompt_text_offset: Optional[int] = None,
    ) -> Optional[CompletionLogprobs]:
        """Create OpenAI-style logprobs for completions."""
        # Parameter validation
        if output_top_logprobs is None or len(output_top_logprobs) < 3 or any(not lst for lst in output_top_logprobs):
            return None

        logprobs_res: Optional[CompletionLogprobs] = None
        # Iterate over the top-k candidates for each generated token
        for logprob_token_ids, logprobs, sampled_token_ranks in zip(
            output_top_logprobs[0], output_top_logprobs[1], output_top_logprobs[2]
        ):
            top_logprobs = LogprobsLists(
                logprob_token_ids=[logprob_token_ids],
                logprobs=[logprobs],
                sampled_token_ranks=[sampled_token_ranks],
            )
            # Build the per-token logprobs response
            step_logprobs_res = self._build_logprobs_response(
                response_logprobs=top_logprobs,
                request_top_logprobs=request_logprobs,
                prompt_text_offset=prompt_text_offset,
            )
            if logprobs_res is None:
                logprobs_res = step_logprobs_res
            else:
                # Append the new tokens to the existing logprobs response
                logprobs_res.tokens.extend(step_logprobs_res.tokens)
                logprobs_res.token_logprobs.extend(step_logprobs_res.token_logprobs)
                logprobs_res.top_logprobs.extend(step_logprobs_res.top_logprobs)
        return logprobs_res

    def _build_logprobs_response(
        self,
        response_logprobs: Optional[LogprobsLists] = None,
        request_top_logprobs: Optional[int] = None,
        prompt_text_offset: Optional[int] = None,
    ) -> Optional[CompletionLogprobs]:
        """
        Construct a logprobs response object in the OpenAI style.
        Retains the complete top-k candidates and avoids circular references.
        """
        # Parameter validation
        if response_logprobs is None or request_top_logprobs is None or request_top_logprobs < 0:
            return None

        try:
            # The top-k candidates for the current token
            topk_token_ids = []
            topk_logprobs = []

            if response_logprobs.logprob_token_ids and len(response_logprobs.logprob_token_ids) > 0:
                topk_token_ids = response_logprobs.logprob_token_ids[0][: request_top_logprobs + 1]

            if response_logprobs.logprobs and len(response_logprobs.logprobs) > 0:
                topk_logprobs = response_logprobs.logprobs[0][: request_top_logprobs + 1]

            # Split the candidates: index 0 is the sampled token, the rest become top_logprobs alternatives
            tokens = []
            token_logprobs = []
            top_logprobs = {}
            idx = 0
            for tid, lp in zip(topk_token_ids, topk_logprobs):
                token_str = self.engine_client.data_processor.process_logprob_response(
                    [tid], clean_up_tokenization_spaces=False
                )
                # Fall back to an escaped-byte representation for tokens that do not decode to valid UTF-8
                if "\ufffd" in token_str:
                    token_bytes = token_str.encode("utf-8", errors="replace")
                    token_str = "bytes:" + "".join(f"\\x{byte:02x}" for byte in token_bytes)
                if idx == 0:
                    tokens.append(token_str)
                    token_logprobs.append(lp)
                else:
                    top_logprobs[token_str] = lp
                idx += 1

            # text_offset is temporarily disabled (see commit message) until the offset logic is corrected
            # text_offset = prompt_text_offset + len(tokens) - 1
            return CompletionLogprobs(
                tokens=tokens,
                token_logprobs=token_logprobs,
                top_logprobs=[top_logprobs],
                # text_offset=[text_offset],
            )
        except Exception as e:
            api_server_logger.error("Error in _build_logprobs_response: %s", e)
            return None
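
For reference, the byte fallback and the resulting OpenAI-style shape can be reproduced in isolation: index 0 of the sliced top-k candidates becomes the sampled token, and the remaining candidates land in the per-token top_logprobs dict. The token strings and logprob values below are made up for illustration:

```python
# Standalone illustration of the byte fallback and the resulting OpenAI-style shape;
# token strings and logprob values are made up.
def display_token(token_str: str) -> str:
    # Mirror of the fallback above: undecodable tokens are shown as escaped bytes
    if "\ufffd" in token_str:
        token_bytes = token_str.encode("utf-8", errors="replace")
        return "bytes:" + "".join(f"\\x{byte:02x}" for byte in token_bytes)
    return token_str

candidates = [("Hello", -0.11), ("\ufffd", -2.37), ("Hi", -3.05)]  # sampled token first
tokens = [display_token(candidates[0][0])]
token_logprobs = [candidates[0][1]]
top_logprobs = [{display_token(t): lp for t, lp in candidates[1:]}]

print(tokens)          # ['Hello']
print(token_logprobs)  # [-0.11]
print(top_logprobs)    # [{'bytes:\\xef\\xbf\\xbd': -2.37, 'Hi': -3.05}]
```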