Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-06 00:57:33 +08:00)
[Feature] Support returning logprobs of generated tokens (#2784)

* online chat supports logprobs
* check xpu
* check vl_gpu_model_runner
* only CUDA supports logprob
* get_worker() checks platform

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
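As a rough usage sketch (not part of this commit): a client exercises the new feature by setting the OpenAI-style logprobs / top_logprobs request parameters; top_logprobs is what serving_chat.py reads as request.top_logprobs in the diff below, and each returned choice then carries a logprobs field built by build_logprobs_response. The endpoint URL, port, and model name here are placeholders.

# Hedged client-side sketch; URL, port and model name are placeholders.
import requests

payload = {
    "model": "default",                                   # placeholder
    "messages": [{"role": "user", "content": "Hello"}],
    "logprobs": True,                                     # OpenAI-style switch (assumed)
    "top_logprobs": 5,                                    # read as request.top_logprobs server-side
    "stream": False,
}
resp = requests.post("http://127.0.0.1:8188/v1/chat/completions",
                     json=payload, timeout=30)
choice = resp.json()["choices"][0]
# choice["logprobs"]["content"] lists the sampled tokens; each entry carries the token,
# its logprob, and a top_logprobs list of alternatives (see build_logprobs_response below).
print(choice["logprobs"]["content"][0])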
@@ -15,34 +15,23 @@
"""
import asyncio
import aiozmq
from aiozmq import zmq
import json
import time
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Callable, Optional, Union, List
import traceback
import uuid
from typing import List, Optional

import aiozmq
from aiozmq import zmq

from fastapi import Request
from pydantic import BaseModel
from fastdeploy.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    DeltaMessage,
    ChatCompletionResponseChoice,
    ChatCompletionStreamResponse,
    ChatCompletionResponseStreamChoice,
    ChatMessage,
    UsageInfo,
    PromptTokenUsageInfo,
    ChatCompletionResponse,
    ErrorResponse,
)
    ChatCompletionRequest, ChatCompletionResponse,
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
    LogProbEntry, LogProbs, PromptTokenUsageInfo, UsageInfo)
from fastdeploy.metrics.work_metrics import work_process_metrics

from fastdeploy.utils import api_server_logger

from fastdeploy.engine.request import RequestOutput

from fastdeploy.worker.output import LogprobsLists


class OpenAIServingChat:
@@ -157,7 +146,7 @@ class OpenAIServingChat:
                        current_waiting_time = 0
                    await asyncio.sleep(0.1)
                    continue


                res = json.loads(raw_data[-1].decode('utf-8'))
                if res.get("error_code", 200) != 200:
                    raise ValueError("{}".format(res["error_msg"]))
@@ -200,6 +189,18 @@ class OpenAIServingChat:

                output = res["outputs"]
                delta_text = output["text"]
                raw_top_logprobs = output["top_logprobs"]
                logprobs_res = None
                if raw_top_logprobs is not None:
                    top_logprobs = LogprobsLists(
                        logprob_token_ids=raw_top_logprobs[0],
                        logprobs=raw_top_logprobs[1],
                        sampled_token_ranks=raw_top_logprobs[2],
                    )
                    logprobs_res = self.build_logprobs_response(
                        logprobs=top_logprobs,
                        request_top_logprobs=request.top_logprobs,
                    )

                previous_num_tokens += len(output["token_ids"])
                delta_message = DeltaMessage(content=delta_text, reasoning_content=output.get("reasoning_content"), \
@@ -208,6 +209,7 @@ class OpenAIServingChat:
                choice = ChatCompletionResponseStreamChoice(
                    index=0,
                    delta=delta_message,
                    logprobs=logprobs_res,
                    arrival_time=arrival_time
                )
                if res["finished"]:
@@ -220,7 +222,7 @@ class OpenAIServingChat:
                            choice.finish_reason = "tool_calls"
                    else:
                        choice.finish_reason = "length"


                    if res.get("error_msg") is not None and "Recover" in res["error_msg"]:
                        choice.finish_reason = "recover_stop"
@@ -286,6 +288,7 @@ class OpenAIServingChat:
        final_res = None
        previous_num_tokens = 0
        current_waiting_time = 0
        logprob_contents = []
        while True:
            try:
                raw_data = await asyncio.wait_for(dealer.read(), timeout=10)
@@ -310,6 +313,21 @@ class OpenAIServingChat:
                    data, stream=False, enable_thinking=enable_thinking)
                # api_server_logger.debug(f"Client {request_id} received: {data}")
                previous_num_tokens += len(data["outputs"]["token_ids"])
                # Handle logprobs carried in the response
                output = data["outputs"]
                raw_top_logprobs = output["top_logprobs"]
                if raw_top_logprobs is not None:
                    top_logprobs = LogprobsLists(
                        logprob_token_ids=raw_top_logprobs[0],
                        logprobs=raw_top_logprobs[1],
                        sampled_token_ranks=raw_top_logprobs[2],
                    )
                    logprobs_res = self.build_logprobs_response(
                        logprobs=top_logprobs,
                        request_top_logprobs=request.top_logprobs,
                    )
                    if logprobs_res and logprobs_res.content is not None:
                        logprob_contents.extend(logprobs_res.content)
                if data["finished"]:
                    final_res = data
                    break
@@ -325,10 +343,16 @@ class OpenAIServingChat:
            tool_calls=output.get("tool_call_content"),
            token_ids=output.get("token_ids")
        )
        logprobs_full_res = None
        if logprob_contents:
            logprobs_full_res = LogProbs(
                content=logprob_contents
            )

        choice = ChatCompletionResponseChoice(
            index=0,
            message=message,
            logprobs=logprobs_full_res,
            finish_reason=None
        )
        if request.max_tokens is None or previous_num_tokens != request.max_tokens:
@@ -338,7 +362,7 @@ class OpenAIServingChat:
                choice.finish_reason = "tool_calls"
        else:
            choice.finish_reason = "length"


        if final_res.get("error_msg") is not None and "Recover" in final_res["error_msg"]:
            choice.finish_reason = "recover_stop"
        choices.append(choice)
@@ -359,3 +383,54 @@ class OpenAIServingChat:
            choices=choices,
            usage=usage
        )

    def build_logprobs_response(
        self,
        logprobs: Optional[LogprobsLists],
        request_top_logprobs: int,
    ) -> Optional[LogProbs]:
        """
        Construct an OpenAI-style logprobs response object.
        Keep the complete top-k candidates and avoid circular references.
        """

        # Parameter validation
        if (
            logprobs is None
            or request_top_logprobs is None
            or request_top_logprobs <= 0
            or len(logprobs.logprob_token_ids) == 0
        ):
            return None

        try:
            # The sampled token plus the top-k candidates for the current position
            topk_token_ids = logprobs.logprob_token_ids[0][:request_top_logprobs + 1]
            topk_logprobs = logprobs.logprobs[0][:request_top_logprobs + 1]

            # Build the top-k candidate token entries (LogProbEntry)
            top_logprob_entries: List[LogProbEntry] = []
            for tid, lp in zip(topk_token_ids, topk_logprobs):
                token_str = self.engine_client.data_processor.process_logprob_response(
                    [tid], clean_up_tokenization_spaces=False)
                # token_bytes = token_str.encode("utf-8", errors="replace")
                entry = LogProbEntry(
                    token=token_str,
                    logprob=lp,
                    # bytes=list(token_bytes)
                )
                top_logprob_entries.append(entry)
            # Build the sampled token entry (avoid sharing references with top_logprob_entries)
            sampled_entry = LogProbEntry(
                token=top_logprob_entries[0].token,
                logprob=top_logprob_entries[0].logprob,
                bytes=top_logprob_entries[0].bytes,
                top_logprobs=top_logprob_entries[1:]  # the remaining top-k candidates
            )

            return LogProbs(content=[sampled_entry])

        except Exception as e:
            api_server_logger.error("Error in build_logprobs_response: %s", e)
            api_server_logger.error(traceback.format_exc())
            return None
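To make the new data flow concrete, a small illustrative sketch with toy values (not taken from the commit): the engine output carries output["top_logprobs"] as a triple of token ids, log-probabilities, and sampled-token ranks, which the serving layer wraps in LogprobsLists and hands to build_logprobs_response. The exact layout of the rank field is an assumption here.

# Illustrative only: LogprobsLists and the keyword names come from the diff above;
# the toy ids, logprobs and the rank layout are assumptions.
from fastdeploy.worker.output import LogprobsLists

raw_top_logprobs = [
    [[512, 87, 903]],       # [0] token ids: sampled token first, then candidates
    [[-0.1, -2.4, -3.8]],   # [1] matching log-probabilities
    [[0]],                  # [2] rank of the sampled token (layout assumed)
]
top_logprobs = LogprobsLists(
    logprob_token_ids=raw_top_logprobs[0],
    logprobs=raw_top_logprobs[1],
    sampled_token_ranks=raw_top_logprobs[2],
)
# Inside OpenAIServingChat this is then converted with:
#   logprobs_res = self.build_logprobs_response(
#       logprobs=top_logprobs,
#       request_top_logprobs=request.top_logprobs,
#   )
# producing LogProbs(content=[LogProbEntry(token=..., logprob=..., top_logprobs=[...])]).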