[Feature] Support return logprob of generated tokens (#2784)

* online chat: support logprobs

* check xpu

* check vl_gpu_model_runner

* only cuda supports logprob

* get_worker() check platform

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Author: chen
Date: 2025-07-10 15:47:42 +08:00 (committed by GitHub)
Parent: 39d2a1de46
Commit: 823a47e64a

21 changed files with 592 additions and 105 deletions
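For context, a client exercises the new feature through FastDeploy's OpenAI-compatible chat endpoint. A minimal client-side sketch (the base URL, API key, and model name below are placeholders, not part of this commit):

# Minimal sketch (assumed host/port/model): request logprobs through the
# OpenAI-compatible endpoint this commit extends.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8188/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello!"}],
    logprobs=True,       # ask the server to return per-token logprobs
    top_logprobs=5,      # top-k alternatives per generated token
)

# Each generated token carries its logprob plus the top-k candidates.
for item in resp.choices[0].logprobs.content:
    print(item.token, item.logprob, [t.token for t in item.top_logprobs])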


@@ -15,34 +15,23 @@
"""
import asyncio
import aiozmq
from aiozmq import zmq
import json
import time
from collections.abc import AsyncGenerator, AsyncIterator
import traceback
import uuid
from typing import List, Optional
from fastapi import Request
from pydantic import BaseModel
from fastdeploy.entrypoints.openai.protocol import (
    ChatCompletionRequest, ChatCompletionResponse,
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
    LogProbEntry, LogProbs, PromptTokenUsageInfo, UsageInfo)
from fastdeploy.metrics.work_metrics import work_process_metrics
from fastdeploy.utils import api_server_logger
from fastdeploy.engine.request import RequestOutput
from fastdeploy.worker.output import LogprobsLists
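The newly imported LogprobsLists is the worker-side container consumed by the handlers below. Judging only from the constructor calls in this diff, it behaves like three parallel lists; a hedged sketch of that shape (the real definition lives in fastdeploy/worker/output.py and may differ):

# Sketch inferred from usage in this file; not the actual fastdeploy definition.
from typing import List, NamedTuple

class LogprobsListsSketch(NamedTuple):
    logprob_token_ids: List[List[int]]   # per position: sampled token id first, then top-k candidate ids
    logprobs: List[List[float]]          # log-probabilities aligned with the ids above
    sampled_token_ranks: List[int]       # rank of the sampled token among the candidates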
class OpenAIServingChat:
@@ -157,7 +146,7 @@ class OpenAIServingChat:
                        current_waiting_time = 0
                await asyncio.sleep(0.1)
                continue

            res = json.loads(raw_data[-1].decode('utf-8'))
            if res.get("error_code", 200) != 200:
                raise ValueError("{}".format(res["error_msg"]))
@@ -200,6 +189,18 @@ class OpenAIServingChat:
            output = res["outputs"]
            delta_text = output["text"]
            raw_top_logprobs = output["top_logprobs"]
            logprobs_res = None
            if raw_top_logprobs is not None:
                top_logprobs = LogprobsLists(
                    logprob_token_ids=raw_top_logprobs[0],
                    logprobs=raw_top_logprobs[1],
                    sampled_token_ranks=raw_top_logprobs[2],
                )
                logprobs_res = self.build_logprobs_response(
                    logprobs=top_logprobs,
                    request_top_logprobs=request.top_logprobs,
                )

            previous_num_tokens += len(output["token_ids"])
            delta_message = DeltaMessage(content=delta_text, reasoning_content=output.get("reasoning_content"), \
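To make the wire format concrete: the indexing above (and in build_logprobs_response below) implies that output["top_logprobs"] arrives as three parallel lists. A hypothetical single-token chunk, with invented values:

# Hypothetical payload for one streamed token (all values invented):
raw_top_logprobs = [
    [[1015, 88, 203]],      # [0] token ids: sampled token first, then the top-k candidates
    [[-0.25, -1.7, -2.3]],  # [1] logprobs aligned with those ids
    [0],                    # [2] rank of the sampled token (0 = it was the most likely)
]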
@@ -208,6 +209,7 @@ class OpenAIServingChat:
            choice = ChatCompletionResponseStreamChoice(
                index=0,
                delta=delta_message,
                logprobs=logprobs_res,
                arrival_time=arrival_time
            )
            if res["finished"]:
@@ -220,7 +222,7 @@ class OpenAIServingChat:
choice.finish_reason = "tool_calls"
else:
choice.finish_reason = "length"
if res.get("error_msg") is not None and "Recover" in res["error_msg"]:
choice.finish_reason = "recover_stop"
@@ -286,6 +288,7 @@ class OpenAIServingChat:
        final_res = None
        previous_num_tokens = 0
        current_waiting_time = 0
        logprob_contents = []
        while True:
            try:
                raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
@@ -310,6 +313,21 @@ class OpenAIServingChat:
                data, stream=False, enable_thinking=enable_thinking)
            # api_server_logger.debug(f"Client {request_id} received: {data}")
            previous_num_tokens += len(data["outputs"]["token_ids"])

            # Collect the logprobs returned with this chunk of the response
            output = data["outputs"]
            raw_top_logprobs = output["top_logprobs"]
            if raw_top_logprobs is not None:
                top_logprobs = LogprobsLists(
                    logprob_token_ids=raw_top_logprobs[0],
                    logprobs=raw_top_logprobs[1],
                    sampled_token_ranks=raw_top_logprobs[2],
                )
                logprobs_res = self.build_logprobs_response(
                    logprobs=top_logprobs,
                    request_top_logprobs=request.top_logprobs,
                )
                if logprobs_res and logprobs_res.content is not None:
                    logprob_contents.extend(logprobs_res.content)

            if data["finished"]:
                final_res = data
                break
@@ -325,10 +343,16 @@ class OpenAIServingChat:
            tool_calls=output.get("tool_call_content"),
            token_ids=output.get("token_ids")
        )

        logprobs_full_res = None
        if logprob_contents:
            logprobs_full_res = LogProbs(
                content=logprob_contents
            )

        choice = ChatCompletionResponseChoice(
            index=0,
            message=message,
            logprobs=logprobs_full_res,
            finish_reason=None
        )
        if request.max_tokens is None or previous_num_tokens != request.max_tokens:
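In the non-streaming path, logprob_contents accumulates one entry per generated token across chunks, so the final response exposes them all at once. An illustrative fragment of the resulting JSON (token strings and numbers invented):

# Illustrative non-streaming response fragment (values invented):
response_fragment = {
    "choices": [{
        "index": 0,
        "message": {"role": "assistant", "content": "Hi"},
        "logprobs": {
            "content": [{
                "token": "Hi",
                "logprob": -0.25,
                "top_logprobs": [
                    {"token": "Hello", "logprob": -1.7},
                    {"token": "Hey", "logprob": -2.3},
                ],
            }],
        },
        "finish_reason": "stop",
    }],
}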
@@ -338,7 +362,7 @@ class OpenAIServingChat:
choice.finish_reason = "tool_calls"
else:
choice.finish_reason = "length"
if final_res.get("error_msg") is not None and "Recover" in final_res["error_msg"]:
choice.finish_reason = "recover_stop"
choices.append(choice)
@@ -359,3 +383,54 @@ class OpenAIServingChat:
            choices=choices,
            usage=usage
        )

    def build_logprobs_response(
        self,
        logprobs: Optional[LogprobsLists],
        request_top_logprobs: int,
    ) -> Optional[LogProbs]:
        """
        Construct an OpenAI-style logprobs response object.
        Keeps the complete top-k candidate list while avoiding shared references.
        """
        # Parameter validation
        if (
            logprobs is None
            or request_top_logprobs is None
            or request_top_logprobs <= 0
            or len(logprobs.logprob_token_ids) == 0
        ):
            return None

        try:
            # Top-k candidates for the current token; index 0 is the sampled token,
            # so request_top_logprobs + 1 entries are kept in total.
            topk_token_ids = logprobs.logprob_token_ids[0][:request_top_logprobs + 1]
            topk_logprobs = logprobs.logprobs[0][:request_top_logprobs + 1]

            # Build a LogProbEntry for each top-k candidate
            top_logprob_entries: List[LogProbEntry] = []
            for tid, lp in zip(topk_token_ids, topk_logprobs):
                token_str = self.engine_client.data_processor.process_logprob_response(
                    [tid], clean_up_tokenization_spaces=False)
                # token_bytes = token_str.encode("utf-8", errors="replace")
                entry = LogProbEntry(
                    token=token_str,
                    logprob=lp,
                    # bytes=list(token_bytes)
                )
                top_logprob_entries.append(entry)

            # Build the sampled-token entry (avoid sharing references with top_logprob_entries)
            sampled_entry = LogProbEntry(
                token=top_logprob_entries[0].token,
                logprob=top_logprob_entries[0].logprob,
                bytes=top_logprob_entries[0].bytes,
                top_logprobs=top_logprob_entries[1:]  # the remaining entries are the top-k candidates
            )
            return LogProbs(content=[sampled_entry])
        except Exception as e:
            api_server_logger.error("Error in build_logprobs_response: %s", e)
            api_server_logger.error(traceback.format_exc())
            return None
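One detail worth spelling out: the method slices request_top_logprobs + 1 candidates because index 0 holds the sampled token itself; that entry is rebuilt as the sampled token, and the remainder become its top_logprobs. A worked example with invented values:

# Worked example (invented values), assuming the request asked for top_logprobs=2:
ids = [[1015, 88, 203]]          # sampled token id first, then the candidates
lps = [[-0.25, -1.7, -2.3]]

k = 2
topk_ids = ids[0][:k + 1]        # -> [1015, 88, 203]; the +1 keeps the sampled token
topk_lps = lps[0][:k + 1]        # -> [-0.25, -1.7, -2.3]

# After decoding, entry 0 becomes the sampled token and entries 1..k become its
# top_logprobs, so the client receives exactly k alternatives per token.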