Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-05 08:37:06 +08:00
[Feature] General support for logprobs (#2974)
* [Feature] support logprobs in chat/completions and completions endpoints
* Temporarily comment out text_offset due to incorrect logic
* Clean up temporary debug prints
* [Feature] support logprobs in offline mode via SamplingParams
* fix: serialize Logprob as dict before zmq send to fix msgpack error
* refactor: remove redundant methods to simplify codebase
* Fix missing fields in CompletionOutput.to_dict affecting msgpack serialization
* refactor: centralize param validation in engine_client to reduce duplication
* revert: rollback changes in offline_demo.py
* revert: rollback changes in offline_demo.py
* [bugfix] fix parameter validation for logprobs
* [bugfix] fix parameter validation for logprobs
* [bugfix] fix parameter validation for logprobs
* [bugfix] fix parameter validation for logprobs

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
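As a rough usage sketch of the feature described above (not part of the commit; the base URL, API key and model name are placeholders, and a locally running FastDeploy OpenAI-compatible server is assumed):

# Sketch only: query logprobs from a locally running OpenAI-compatible
# FastDeploy server. base_url, api_key and model are placeholder values.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# Completions endpoint: "logprobs" is the number of top candidates per token.
completion = client.completions.create(
    model="default",
    prompt="Hello, my name is",
    max_tokens=8,
    logprobs=5,
)
print(completion.choices[0].logprobs.tokens)
print(completion.choices[0].logprobs.token_logprobs)

# Chat completions endpoint: "logprobs" is a flag, "top_logprobs" the candidate count.
chat = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello"}],
    logprobs=True,
    top_logprobs=5,
)
print(chat.choices[0].logprobs)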
@@ -17,7 +17,7 @@
 import asyncio
 import time
 import uuid
-from typing import List
+from typing import List, Optional

 import aiozmq
 import msgpack
@@ -26,6 +26,7 @@ from aiozmq import zmq

 from fastdeploy.engine.request import RequestOutput
 from fastdeploy.entrypoints.openai.protocol import (
+    CompletionLogprobs,
     CompletionRequest,
     CompletionResponse,
     CompletionResponseChoice,
@@ -35,6 +36,7 @@ from fastdeploy.entrypoints.openai.protocol import (
     UsageInfo,
 )
 from fastdeploy.utils import api_server_logger, get_host_ip
+from fastdeploy.worker.output import LogprobsLists


 class OpenAIServingCompletion:
@@ -160,6 +162,8 @@ class OpenAIServingCompletion:

         valid_results = [dict()] * num_choices
         output_tokens = [0] * num_choices
+        aggregated_top_logprobs = [[[], [], []]] * num_choices
+        aggregated_token_ids = [[]] * num_choices
         completion_batched_token_ids = [[] for _ in range(num_choices)]
         current_waiting_time = 0
         while num_choices > 0:
@@ -182,6 +186,15 @@
             if data.get("error_code", 200) != 200:
                 raise ValueError("{}".format(data["error_msg"]))

+            output = data["outputs"]
+            output_top_logprobs = output["top_logprobs"]
+            if output_top_logprobs is not None:
+                aggregated_top_logprobs[rid][0].extend(output_top_logprobs[0])
+                aggregated_top_logprobs[rid][1].extend(output_top_logprobs[1])
+                aggregated_top_logprobs[rid][2].extend(output_top_logprobs[2])
+
+            aggregated_token_ids[rid].extend(data["outputs"]["token_ids"])
+
             self.engine_client.data_processor.process_response_dict(
                 data, stream=False, include_stop_str_in_output=request.include_stop_str_in_output
             )
@@ -189,6 +202,8 @@ class OpenAIServingCompletion:
             completion_batched_token_ids[rid].extend(data["outputs"]["token_ids"])
             if data.get("finished", False):
                 data["output_token_ids"] = output_tokens[rid]
+                data["outputs"]["top_logprobs"] = aggregated_top_logprobs[rid]
+                data["outputs"]["token_ids"] = aggregated_token_ids[rid]
                 valid_results[rid] = data
                 num_choices -= 1
                 break
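A note on the data layout handled above: `top_logprobs` arrives from the worker as three parallel lists per response chunk (candidate token ids, their log-probabilities, and the rank of the sampled token), which is why indices 0, 1 and 2 are extended separately across chunks. A minimal sketch with invented values (none of the ids or probabilities below come from the commit):

# Hypothetical worker payloads; the numbers are invented for illustration.
# index 0: top-k candidate token ids per generated token
# index 1: matching log-probabilities
# index 2: rank of the sampled token among the candidates
step_a = [[[101, 7, 42]], [[-0.11, -2.30, -3.02]], [[0]]]
step_b = [[[55, 9, 13]], [[-0.47, -1.83, -2.95]], [[1]]]

aggregated = [[], [], []]
for step in (step_a, step_b):
    for i in range(3):
        aggregated[i].extend(step[i])

print(aggregated[0])  # [[101, 7, 42], [55, 9, 13]] -- one candidate list per token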
@@ -292,6 +307,10 @@
                 arrival_time = res["metrics"]["arrival_time"] - inference_start_time[idx]

                 output = res["outputs"]
+                output_top_logprobs = output["top_logprobs"]
+                logprobs_res: Optional[CompletionLogprobs] = None
+                if request.logprobs and output_top_logprobs is not None:
+                    logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)

                 choices.append(
                     CompletionResponseStreamChoice(
@@ -302,6 +321,7 @@
                         tool_calls=output.get("tool_call_content"),
                         reasoning_content=output.get("reasoning_content"),
                         arrival_time=arrival_time,
+                        logprobs=logprobs_res,
                     )
                 )
                 if res["finished"]:
@@ -367,6 +387,7 @@
         choices: List[CompletionResponseChoice] = []
         num_prompt_tokens = 0
         num_generated_tokens = 0
+        aggregated_logprobs: Optional[CompletionLogprobs] = None

         for idx in range(len(final_res_batch)):
             final_res = final_res_batch[idx]
@@ -376,6 +397,18 @@
             completion_token_ids = completion_batched_token_ids[idx]

+            output = final_res["outputs"]
+            output_top_logprobs = output["top_logprobs"]
+
+            if output_top_logprobs is not None:
+                logprobs_res = self._create_completion_logprobs(output_top_logprobs, request.logprobs, 0)
+                if aggregated_logprobs is None:
+                    aggregated_logprobs = logprobs_res
+                else:
+                    aggregated_logprobs.tokens.extend(logprobs_res.tokens)
+                    aggregated_logprobs.token_logprobs.extend(logprobs_res.token_logprobs)
+                    aggregated_logprobs.top_logprobs.extend(logprobs_res.top_logprobs)
+                    aggregated_logprobs.text_offset.extend(logprobs_res.text_offset)

             if request.echo:
                 assert prompt_text is not None
                 if request.max_tokens == 0:
@@ -396,7 +429,7 @@
                 completion_token_ids=completion_token_ids if request.return_token_ids else None,
                 reasoning_content=output.get("reasoning_content"),
                 tool_calls=output.get("tool_call_content"),
-                logprobs=None,
+                logprobs=aggregated_logprobs,
                 finish_reason=None,
             )
             choices.append(choice_data)
@@ -419,3 +452,99 @@ class OpenAIServingCompletion:
             choices=choices,
             usage=usage,
         )
+
+    def _create_completion_logprobs(
+        self,
+        output_top_logprobs,
+        request_logprobs: Optional[int] = None,
+        prompt_text_offset: Optional[int] = None,
+    ) -> Optional[CompletionLogprobs]:
+        """Create OpenAI-style logprobs for completions."""
+
+        # Parameter validation
+        if output_top_logprobs is None or len(output_top_logprobs) < 3 or any(not lst for lst in output_top_logprobs):
+            return None
+
+        logprobs_res: Optional[CompletionLogprobs] = None
+        # Iterate over the top-k candidates for each token
+        for logprob_token_ids, logprobs, sampled_token_ranks in zip(
+            output_top_logprobs[0], output_top_logprobs[1], output_top_logprobs[2]
+        ):
+            top_logprobs = LogprobsLists(
+                logprob_token_ids=[logprob_token_ids],
+                logprobs=[logprobs],
+                sampled_token_ranks=[sampled_token_ranks],
+            )
+            # Build the logprobs response
+            step_logprobs_res = self._build_logprobs_response(
+                response_logprobs=top_logprobs,
+                request_top_logprobs=request_logprobs,
+                prompt_text_offset=prompt_text_offset,
+            )
+            if logprobs_res is None:
+                logprobs_res = step_logprobs_res
+            else:
+                # Append the new tokens to the existing logprobs response
+                logprobs_res.tokens.extend(step_logprobs_res.tokens)
+                logprobs_res.token_logprobs.extend(step_logprobs_res.token_logprobs)
+                logprobs_res.top_logprobs.extend(step_logprobs_res.top_logprobs)

+        return logprobs_res
+
+    def _build_logprobs_response(
+        self,
+        response_logprobs: Optional[LogprobsLists] = None,
+        request_top_logprobs: Optional[int] = None,
+        prompt_text_offset: Optional[int] = None,
+    ) -> Optional[CompletionLogprobs]:
+        """
+        Construct a logprobs response object in line with the OpenAI style.
+        Retain the complete top-k candidates and avoid circular references.
+        """
+
+        # Parameter validation
+        if response_logprobs is None or request_top_logprobs is None or request_top_logprobs < 0:
+            return None
+
+        try:
+            # The top-k candidates for the current token
+            topk_token_ids = []
+            topk_logprobs = []
+
+            if response_logprobs.logprob_token_ids and len(response_logprobs.logprob_token_ids) > 0:
+                topk_token_ids = response_logprobs.logprob_token_ids[0][: request_top_logprobs + 1]
+
+            if response_logprobs.logprobs and len(response_logprobs.logprobs) > 0:
+                topk_logprobs = response_logprobs.logprobs[0][: request_top_logprobs + 1]
+
+            # Construct the sampled token object (avoid sharing references with top_logprob_entries)
+            tokens = []
+            token_logprobs = []
+            top_logprobs = {}
+            idx = 0
+            for tid, lp in zip(topk_token_ids, topk_logprobs):
+                token_str = self.engine_client.data_processor.process_logprob_response(
+                    [tid], clean_up_tokenization_spaces=False
+                )
+                if "\ufffd" in token_str:
+                    token_bytes = token_str.encode("utf-8", errors="replace")
+                    token_str = "bytes:" + "".join(f"\\x{byte:02x}" for byte in token_bytes)
+                if idx == 0:
+                    tokens.append(token_str)
+                    token_logprobs.append(lp)
+                else:
+                    top_logprobs[token_str] = lp
+                idx += 1
+
+            # Construct the sampled token object (avoid sharing references with top_logprob_entries)
+            # text_offset = prompt_text_offset + len(tokens) - 1
+            return CompletionLogprobs(
+                tokens=tokens,
+                token_logprobs=token_logprobs,
+                top_logprobs=[top_logprobs],
+                # text_offset=[text_offset],
+            )
+
+        except Exception as e:
+            api_server_logger.error("Error in _build_logprobs_response: %s", e)
+            return None
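For orientation, the object returned by the new `_build_logprobs_response` helper follows the classic OpenAI completions `logprobs` shape: the sampled token and its logprob go into `tokens`/`token_logprobs`, the remaining candidates into a per-step `top_logprobs` dict, and `text_offset` is left unpopulated for now (see the commit message). A purely illustrative example, with invented token strings and values:

# Illustrative only: invented tokens and log-probabilities.
logprobs_payload = {
    "tokens": ["Hello", ","],              # sampled token per step
    "token_logprobs": [-0.12, -1.05],      # logprob of each sampled token
    "top_logprobs": [                      # remaining top-k candidates per step
        {"Hi": -2.31, "Hey": -2.88},
        {"!": -1.87, ".": -2.40},
    ],
    # "text_offset" is temporarily omitted in this commit.
}
print(logprobs_payload["tokens"])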