[Feature] Online Chat API Support Return logprobs (#2777)

* online chat support logprobs

* check xpu

* check vl_gpu_model_runner and xpu_model_runner

* get_worker() check platform
This commit is contained in:
chen
2025-07-10 16:33:40 +08:00
committed by GitHub
parent 24f934f1f9
commit d33105baeb
22 changed files with 608 additions and 114 deletions

View File

@@ -122,6 +122,7 @@ class ChatCompletionResponseChoice(BaseModel):
"""
index: int
message: ChatMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]]
@@ -136,6 +137,21 @@ class ChatCompletionResponse(BaseModel):
choices: List[ChatCompletionResponseChoice]
usage: UsageInfo
class LogProbEntry(BaseModel):
"""
Log probability entry.
"""
token: str
logprob: float
bytes: Optional[List[int]] = None
top_logprobs: Optional[List["LogProbEntry"]] = None
class LogProbs(BaseModel):
"""
LogProbs.
"""
content: Optional[List[LogProbEntry]] = None
refusal: Optional[Union[str, None]] = None
class DeltaMessage(BaseModel):
"""
@@ -154,6 +170,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
"""
index: int
delta: DeltaMessage
logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
arrival_time: Optional[float] = None
@@ -392,6 +409,8 @@ class ChatCompletionRequest(BaseModel):
tools: Optional[List[ChatCompletionToolsParam]] = None
model: Optional[str] = "default"
frequency_penalty: Optional[float] = None
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = 0
# remove max_tokens when field is removed from OpenAI API
max_tokens: Optional[int] = Field(
default=None,
@@ -434,6 +453,9 @@ class ChatCompletionRequest(BaseModel):
if request_id is not None:
req_dict['request_id'] = request_id
req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens
req_dict["logprobs"] = self.top_logprobs if self.logprobs else None
if self.metadata is not None:
for key, value in self.metadata.items():
req_dict[key] = value
@@ -505,3 +527,18 @@ class ChatCompletionRequest(BaseModel):
)
return data
@model_validator(mode="before")
@classmethod
def check_logprobs(cls, data):
if (top_logprobs := data.get("top_logprobs")) is not None:
if top_logprobs < 0:
raise ValueError("`top_logprobs` must be a positive value.")
if top_logprobs > 0 and not data.get("logprobs"):
raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true."
)
return data