[LogProbs] Enable prompt logprobs output and modify the data transmission method for the online interface. (#5089)

* add prompt logprobs

* Merge prompt_logprobs_tensors and prompt_logprobs

* fix param check

* trigger ci

* fix unit tests

* fix logprobs bug
Author: qwes5s5
Date: 2025-12-02 13:49:51 +08:00
Committed by: GitHub
Parent: af39819fcd
Commit: 117980dd4e
27 changed files with 4947 additions and 233 deletions
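
For context, a minimal sketch of how a client would exercise the new parameter against the OpenAI-compatible endpoint; the host, port, model name, and chosen values are placeholders, not taken from this PR:

```python
# Hypothetical request using the new `prompt_logprobs` field.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "Hello"}],
        "logprobs": True,
        "top_logprobs": 2,
        "prompt_logprobs": 0,  # new field added by this PR
    },
)
print(resp.json()["choices"][0].get("prompt_logprobs"))
```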


@@ -21,9 +21,17 @@ import time
 import uuid
 from typing import Annotated, Any, Dict, List, Literal, Optional, Union
-from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    ValidationInfo,
+    field_validator,
+    model_validator,
+)
 from fastdeploy.engine.pooling_params import PoolingParams
+from fastdeploy.worker.output import PromptLogprobs
 class InvalidParameterException(Exception):
@@ -214,10 +222,12 @@ class ChatCompletionResponseChoice(BaseModel):
     Chat completion response choice.
     """
+    model_config = ConfigDict(arbitrary_types_allowed=True)
     index: int
     message: ChatMessage
     logprobs: Optional[LogProbs] = None
     draft_logprobs: Optional[LogProbs] = None
+    prompt_logprobs: Optional[PromptLogprobs] = None
     finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]]
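
PromptLogprobs comes from fastdeploy.worker.output and is not a pydantic model, so every response schema that carries it opts in to arbitrary field types. A reduced sketch of why the new model_config line is needed (the PromptLogprobs stand-in below is a plain class, not the real type):

```python
from typing import Optional

from pydantic import BaseModel, ConfigDict


class PromptLogprobs:  # stand-in for fastdeploy.worker.output.PromptLogprobs
    pass


class Choice(BaseModel):
    # Without this line, pydantic v2 raises PydanticSchemaGenerationError at
    # class-definition time because it cannot build a schema for a plain
    # (non-pydantic) class; with it, validation falls back to an isinstance()
    # check.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    prompt_logprobs: Optional[PromptLogprobs] = None


Choice(prompt_logprobs=PromptLogprobs())  # accepted via the isinstance check
```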
@@ -275,10 +285,12 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     Chat completion response choice for stream response.
     """
+    model_config = ConfigDict(arbitrary_types_allowed=True)
     index: int
     delta: DeltaMessage
     logprobs: Optional[LogProbs] = None
     draft_logprobs: Optional[LogProbs] = None
+    prompt_logprobs: Optional[PromptLogprobs] = None
     finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
     arrival_time: Optional[float] = None
@@ -301,6 +313,7 @@ class CompletionResponseChoice(BaseModel):
     Completion response choice.
     """
+    model_config = ConfigDict(arbitrary_types_allowed=True)
     index: int
     text: str
     prompt_token_ids: Optional[List[int]] = None
@@ -310,6 +323,7 @@ class CompletionResponseChoice(BaseModel):
     arrival_time: Optional[float] = None
     logprobs: Optional[CompletionLogprobs] = None
     draft_logprobs: Optional[CompletionLogprobs] = None
+    prompt_logprobs: Optional[PromptLogprobs] = None
     reasoning_content: Optional[str] = None
     finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
     tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None
@@ -344,11 +358,13 @@ class CompletionResponseStreamChoice(BaseModel):
     Completion response choice for stream response.
     """
+    model_config = ConfigDict(arbitrary_types_allowed=True)
     index: int
     text: str
     arrival_time: float = None
     logprobs: Optional[CompletionLogprobs] = None
     draft_logprobs: Optional[CompletionLogprobs] = None
+    prompt_logprobs: Optional[PromptLogprobs] = None
     prompt_token_ids: Optional[List[int]] = None
     completion_token_ids: Optional[List[int]] = None
     prompt_tokens: Optional[str] = None
@@ -437,6 +453,7 @@ class CompletionRequest(BaseModel):
     frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
     logprobs: Optional[int] = None
     include_draft_logprobs: Optional[bool] = False
+    prompt_logprobs: Optional[int] = None
     # For logits and logprobs post processing
     temp_scaled_logprobs: bool = False
     top_p_normalized_logprobs: bool = False
@@ -569,6 +586,18 @@ class CompletionRequest(BaseModel):
         return data
+    @model_validator(mode="before")
+    @classmethod
+    def check_logprobs(cls, data):
+        if (logprobs := data.get("logprobs")) is not None:
+            if logprobs < -1:
+                raise ValueError("`logprobs` must be greater than or equal to -1.")
+        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
+            if prompt_logprobs < -1:
+                raise ValueError("`prompt_logprobs` must be greater than or equal to -1.")
+        return data
 class ChatCompletionRequest(BaseModel):
     """
@@ -583,6 +612,7 @@ class ChatCompletionRequest(BaseModel):
     frequency_penalty: Optional[float] = Field(None, le=2, ge=-2)
     logprobs: Optional[bool] = False
     top_logprobs: Optional[int] = 0
+    prompt_logprobs: Optional[int] = None
     include_draft_logprobs: Optional[bool] = False
     # For logits and logprobs post processing
@@ -651,6 +681,7 @@ class ChatCompletionRequest(BaseModel):
         req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens
         req_dict["logprobs"] = self.top_logprobs if self.logprobs else None
+        req_dict["prompt_logprobs"] = self.prompt_logprobs
         req_dict["temp_scaled_logprobs"] = self.temp_scaled_logprobs
         req_dict["top_p_normalized_logprobs"] = self.top_p_normalized_logprobs
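
The chat API exposes logprobs as a bool plus a top_logprobs count, while the engine-side dict carries a single integer; prompt_logprobs passes through unchanged. A standalone sketch of that collapse (the function name is hypothetical, not part of the PR):

```python
def collapse_logprob_params(logprobs: bool, top_logprobs: int, prompt_logprobs):
    # Mirrors the two assignments above: bool + count -> one int (or None),
    # prompt_logprobs forwarded as-is.
    return {
        "logprobs": top_logprobs if logprobs else None,
        "prompt_logprobs": prompt_logprobs,
    }


assert collapse_logprob_params(True, 5, -1) == {"logprobs": 5, "prompt_logprobs": -1}
assert collapse_logprob_params(False, 5, None)["logprobs"] is None
```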
@@ -751,12 +782,15 @@ class ChatCompletionRequest(BaseModel):
     def check_logprobs(cls, data):
         if (top_logprobs := data.get("top_logprobs")) is not None:
-            if top_logprobs < 0:
-                raise ValueError("`top_logprobs` must be a positive value.")
+            if top_logprobs < -1:
+                raise ValueError("`top_logprobs` must be greater than or equal to -1.")
-            if top_logprobs > 0 and not data.get("logprobs"):
+            if not data.get("logprobs"):
                 raise ValueError("when using `top_logprobs`, `logprobs` must be set to true.")
+        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
+            if prompt_logprobs < -1:
+                raise ValueError("`prompt_logprobs` must be greater than or equal to -1.")
         return data
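
Two behavioral changes fall out of the rewritten check: top_logprobs=-1 is now accepted (previously any negative value was rejected), and `logprobs` must now be true whenever top_logprobs is supplied at all, not only when it is positive. A reduced, runnable sketch of the new rule:

```python
from typing import Optional

from pydantic import BaseModel, ValidationError, model_validator


class ChatReq(BaseModel):  # trimmed stand-in for ChatCompletionRequest
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = 0
    prompt_logprobs: Optional[int] = None

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        if (top_logprobs := data.get("top_logprobs")) is not None:
            if top_logprobs < -1:
                raise ValueError("`top_logprobs` must be greater than or equal to -1.")
            if not data.get("logprobs"):
                raise ValueError("when using `top_logprobs`, `logprobs` must be set to true.")
        if (pl := data.get("prompt_logprobs")) is not None and pl < -1:
            raise ValueError("`prompt_logprobs` must be greater than or equal to -1.")
        return data


ChatReq(logprobs=True, top_logprobs=-1)  # accepted after this PR
try:
    ChatReq(top_logprobs=0)  # now rejected: `logprobs` is not set to true
except ValidationError as exc:
    print(exc)
```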