Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Feature] Online Chat API Support Return logprobs (#2777)
* online chat support logprobs
* check xpu
* check vl_gpu_model_runner and xpu_model_runner
* get_worker() check platform
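What the feature looks like from a client's perspective; a minimal sketch using the `openai` Python client against a FastDeploy OpenAI-compatible endpoint. The base URL, API key, and prompt are illustrative assumptions, not taken from this commit:

```python
# Minimal sketch: request per-token logprobs from the chat completions API.
# base_url/api_key are placeholders for a local FastDeploy deployment.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")

resp = client.chat.completions.create(
    model="default",  # matches the protocol's default model name
    messages=[{"role": "user", "content": "Hello!"}],
    logprobs=True,    # enable per-token log probabilities
    top_logprobs=5,   # also return the 5 most likely alternatives per token
)

for entry in resp.choices[0].logprobs.content:
    print(entry.token, entry.logprob)
```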
@@ -122,6 +122,7 @@ class ChatCompletionResponseChoice(BaseModel):
     """
     index: int
     message: ChatMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]]
@@ -136,6 +137,21 @@ class ChatCompletionResponse(BaseModel):
     choices: List[ChatCompletionResponseChoice]
     usage: UsageInfo


+class LogProbEntry(BaseModel):
+    """
+    Log probability entry.
+    """
+    token: str
+    logprob: float
+    bytes: Optional[List[int]] = None
+    top_logprobs: Optional[List["LogProbEntry"]] = None
+
+
+class LogProbs(BaseModel):
+    """
+    LogProbs.
+    """
+    content: Optional[List[LogProbEntry]] = None
+    refusal: Optional[Union[str, None]] = None
+
+
 class DeltaMessage(BaseModel):
     """
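Taken on their own, the two new models can be exercised like this; a standalone sketch assuming pydantic v2 (which the `model_validator` used later in this file implies), with made-up token values:

```python
# Standalone sketch of the new response models (mirrors the diff above).
from typing import List, Optional, Union
from pydantic import BaseModel

class LogProbEntry(BaseModel):
    token: str
    logprob: float
    bytes: Optional[List[int]] = None
    top_logprobs: Optional[List["LogProbEntry"]] = None  # self-referential

class LogProbs(BaseModel):
    content: Optional[List[LogProbEntry]] = None
    refusal: Optional[Union[str, None]] = None

# One sampled token plus its top-2 alternatives, nested in the same entry type.
entry = LogProbEntry(
    token="Hello",
    logprob=-0.12,
    top_logprobs=[
        LogProbEntry(token="Hello", logprob=-0.12),
        LogProbEntry(token="Hi", logprob=-2.3),
    ],
)
print(LogProbs(content=[entry]).model_dump_json())
```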
@@ -154,6 +170,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     """
     index: int
     delta: DeltaMessage
+    logprobs: Optional[LogProbs] = None
     finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
     arrival_time: Optional[float] = None
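With the field also on stream choices, logprobs arrive chunk by chunk alongside each text delta; a sketch reusing the same illustrative endpoint and model name as the earlier snippet:

```python
# Sketch: consuming logprobs from a streaming response.
# Endpoint and model name are placeholders, not taken from this commit.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="none")

stream = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello!"}],
    logprobs=True,
    top_logprobs=2,
    stream=True,
)

for chunk in stream:
    if not chunk.choices:
        continue
    choice = chunk.choices[0]
    # Each delta may carry a LogProbs payload alongside the text fragment.
    if choice.logprobs and choice.logprobs.content:
        for entry in choice.logprobs.content:
            print(entry.token, entry.logprob)
```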
@@ -392,6 +409,8 @@ class ChatCompletionRequest(BaseModel):
     tools: Optional[List[ChatCompletionToolsParam]] = None
     model: Optional[str] = "default"
     frequency_penalty: Optional[float] = None
+    logprobs: Optional[bool] = False
+    top_logprobs: Optional[int] = 0
     # remove max_tokens when field is removed from OpenAI API
     max_tokens: Optional[int] = Field(
         default=None,
@@ -434,6 +453,9 @@ class ChatCompletionRequest(BaseModel):
         if request_id is not None:
             req_dict['request_id'] = request_id

+        req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens
+        req_dict["logprobs"] = self.top_logprobs if self.logprobs else None
+
         if self.metadata is not None:
             for key, value in self.metadata.items():
                 req_dict[key] = value
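The mapping above collapses the OpenAI-style pair (a `logprobs` boolean plus a `top_logprobs` count) into the single integer-or-None field the inference request expects; a tiny hypothetical helper, not part of the commit, just to spell out the truth table:

```python
# Hypothetical helper showing the mapping: the boolean `logprobs` gates
# whether `top_logprobs` is forwarded at all.
from typing import Optional

def map_logprobs(logprobs: bool, top_logprobs: int) -> Optional[int]:
    return top_logprobs if logprobs else None

assert map_logprobs(True, 5) == 5      # top 5 alternatives per token
assert map_logprobs(True, 0) == 0      # logprobs for the sampled token only
assert map_logprobs(False, 0) is None  # logprobs disabled entirely
```

The remaining combination, `top_logprobs > 0` without `logprobs`, is rejected up front by the validator added below.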
@@ -505,3 +527,18 @@ class ChatCompletionRequest(BaseModel):
             )

         return data
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_logprobs(cls, data):
+
+        if (top_logprobs := data.get("top_logprobs")) is not None:
+            if top_logprobs < 0:
+                raise ValueError("`top_logprobs` must be a positive value.")
+
+            if top_logprobs > 0 and not data.get("logprobs"):
+                raise ValueError(
+                    "when using `top_logprobs`, `logprobs` must be set to true."
+                )
+
+        return data