Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00
[LogProbs] Enable prompt logprobs output and modify the data transmission method for the online interface. (#5089)
* add prompt logprobs
* merge prompt_logprobs_tensors and prompt_logprobs
* fix param check
* trigger ci
* fix unit test
* fix logprobs bug
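For context before the diff, this is what the new parameter looks like from the client side. A minimal sketch, assuming a local FastDeploy server on port 8000 exposing the OpenAI-compatible /v1/completions route; the model name is a placeholder and the "top-k per prompt token" reading of the value is illustrative, not confirmed by this commit:

    import json
    import urllib.request

    payload = {
        "model": "default",                 # placeholder model name
        "prompt": "The capital of France is",
        "max_tokens": 8,
        "prompt_logprobs": 5,               # new field; the validator below accepts any value >= -1
    }
    req = urllib.request.Request(
        "http://localhost:8000/v1/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        choice = json.load(resp)["choices"][0]
        print(choice.get("prompt_logprobs"))  # populated via the new response-choice field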
@@ -21,9 +21,17 @@ import time
 import uuid
 from typing import Annotated, Any, Dict, List, Literal, Optional, Union
 
-from pydantic import BaseModel, Field, ValidationInfo, field_validator, model_validator
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    ValidationInfo,
+    field_validator,
+    model_validator,
+)
 
 from fastdeploy.engine.pooling_params import PoolingParams
+from fastdeploy.worker.output import PromptLogprobs
 
 
 class InvalidParameterException(Exception):
@@ -214,10 +222,12 @@ class ChatCompletionResponseChoice(BaseModel):
     Chat completion response choice.
     """
 
+    model_config = ConfigDict(arbitrary_types_allowed=True)
     index: int
     message: ChatMessage
     logprobs: Optional[LogProbs] = None
     draft_logprobs: Optional[LogProbs] = None
+    prompt_logprobs: Optional[PromptLogprobs] = None
     finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]]
 
 
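The added model_config line is what lets a non-pydantic class such as PromptLogprobs appear as a field type. A minimal sketch of the pydantic v2 behavior involved, using a stand-in class instead of the real fastdeploy.worker.output.PromptLogprobs:

    from typing import Optional

    from pydantic import BaseModel, ConfigDict


    class PromptLogprobs:  # stand-in for the real class from fastdeploy.worker.output
        pass


    class Choice(BaseModel):
        # Without arbitrary_types_allowed=True, pydantic v2 raises
        # PydanticSchemaGenerationError for field types it cannot build a schema for.
        model_config = ConfigDict(arbitrary_types_allowed=True)
        prompt_logprobs: Optional[PromptLogprobs] = None


    print(Choice(prompt_logprobs=PromptLogprobs()))  # validated by isinstance check only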
@@ -275,10 +285,12 @@ class ChatCompletionResponseStreamChoice(BaseModel):
     Chat completion response choice for stream response.
     """
 
+    model_config = ConfigDict(arbitrary_types_allowed=True)
     index: int
     delta: DeltaMessage
     logprobs: Optional[LogProbs] = None
     draft_logprobs: Optional[LogProbs] = None
+    prompt_logprobs: Optional[PromptLogprobs] = None
     finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
     arrival_time: Optional[float] = None
 
@@ -301,6 +313,7 @@ class CompletionResponseChoice(BaseModel):
     Completion response choice.
     """
 
+    model_config = ConfigDict(arbitrary_types_allowed=True)
     index: int
     text: str
     prompt_token_ids: Optional[List[int]] = None
@@ -310,6 +323,7 @@ class CompletionResponseChoice(BaseModel):
     arrival_time: Optional[float] = None
     logprobs: Optional[CompletionLogprobs] = None
     draft_logprobs: Optional[CompletionLogprobs] = None
+    prompt_logprobs: Optional[PromptLogprobs] = None
     reasoning_content: Optional[str] = None
     finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
     tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None
@@ -344,11 +358,13 @@ class CompletionResponseStreamChoice(BaseModel):
     Completion response choice for stream response.
     """
 
+    model_config = ConfigDict(arbitrary_types_allowed=True)
     index: int
     text: str
     arrival_time: float = None
     logprobs: Optional[CompletionLogprobs] = None
     draft_logprobs: Optional[CompletionLogprobs] = None
+    prompt_logprobs: Optional[PromptLogprobs] = None
    prompt_token_ids: Optional[List[int]] = None
     completion_token_ids: Optional[List[int]] = None
     prompt_tokens: Optional[str] = None
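The same pair of additions repeats across all four response-choice models, so prompt logprobs ride along in both streaming and non-streaming replies. The serialized shape is defined by PromptLogprobs in fastdeploy.worker.output, not by this file; the structure below is only a plausible illustration (one entry per prompt token, mapping token id to logprob details, None for the first token), patterned after similar engines rather than taken from this commit:

    # Illustrative only; the real layout depends on fastdeploy.worker.output.PromptLogprobs.
    prompt_logprobs_example = [
        None,  # first prompt token has no preceding context to score
        {1734: {"logprob": -0.41, "rank": 1, "decoded_token": " capital"}},
        {286: {"logprob": -1.93, "rank": 2, "decoded_token": " of"}},
    ]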
@@ -437,6 +453,7 @@ class CompletionRequest(BaseModel):
     frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
     logprobs: Optional[int] = None
     include_draft_logprobs: Optional[bool] = False
+    prompt_logprobs: Optional[int] = None
     # For logits and logprobs post processing
     temp_scaled_logprobs: bool = False
     top_p_normalized_logprobs: bool = False
@@ -569,6 +586,18 @@ class CompletionRequest(BaseModel):
 
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_logprobs(cls, data):
+        if (logprobs := data.get("logprobs")) is not None:
+            if logprobs < -1:
+                raise ValueError("`logprobs` must be a greater than -1.")
+
+        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
+            if prompt_logprobs < -1:
+                raise ValueError("`prompt_logprobs` must be a greater than -1.")
+        return data
+
 
 class ChatCompletionRequest(BaseModel):
     """
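The new CompletionRequest validator rejects anything below -1 for either field. A standalone restatement of that logic, to make the boundary cases explicit (-1 is the smallest accepted value):

    def check_logprobs(data: dict) -> dict:
        # Mirrors the model_validator above: None is fine, -1 and up are fine,
        # anything below -1 is rejected for both fields.
        if (logprobs := data.get("logprobs")) is not None and logprobs < -1:
            raise ValueError("`logprobs` must be greater than -1.")
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None and prompt_logprobs < -1:
            raise ValueError("`prompt_logprobs` must be greater than -1.")
        return data


    check_logprobs({"prompt_logprobs": 5})   # accepted
    check_logprobs({"prompt_logprobs": -1})  # accepted: the lower bound
    # check_logprobs({"prompt_logprobs": -2})  would raise ValueError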
@@ -583,6 +612,7 @@ class ChatCompletionRequest(BaseModel):
     frequency_penalty: Optional[float] = Field(None, le=2, ge=-2)
     logprobs: Optional[bool] = False
     top_logprobs: Optional[int] = 0
+    prompt_logprobs: Optional[int] = None
     include_draft_logprobs: Optional[bool] = False
 
     # For logits and logprobs post processing
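Note the asymmetry: ChatCompletionRequest keeps the OpenAI-style boolean logprobs plus integer top_logprobs for sampled tokens, while the new prompt_logprobs is a bare integer in both request types. A sketch of a chat payload using all three (the model name is a placeholder):

    chat_payload = {
        "model": "default",
        "messages": [{"role": "user", "content": "Hello"}],
        "logprobs": True,      # bool: turn sampled-token logprobs on
        "top_logprobs": 3,     # int: how many alternatives per sampled token
        "prompt_logprobs": 3,  # int: new, applies to prompt tokens directly
    }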
@@ -651,6 +681,7 @@ class ChatCompletionRequest(BaseModel):
 
         req_dict["max_tokens"] = self.max_completion_tokens or self.max_tokens
         req_dict["logprobs"] = self.top_logprobs if self.logprobs else None
+        req_dict["prompt_logprobs"] = self.prompt_logprobs
         req_dict["temp_scaled_logprobs"] = self.temp_scaled_logprobs
         req_dict["top_p_normalized_logprobs"] = self.top_p_normalized_logprobs
 
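In the conversion to the internal request dict, the chat-side bool/int pair collapses into a single integer (or None), while prompt_logprobs passes through untouched. A minimal restatement of the two assignments above:

    def logprob_fields(logprobs: bool, top_logprobs: int, prompt_logprobs):
        # Mirrors the two new/existing req_dict assignments.
        return {
            "logprobs": top_logprobs if logprobs else None,
            "prompt_logprobs": prompt_logprobs,
        }


    print(logprob_fields(True, 3, 5))   # {'logprobs': 3, 'prompt_logprobs': 5}
    print(logprob_fields(False, 3, 5))  # {'logprobs': None, 'prompt_logprobs': 5}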
@@ -751,12 +782,15 @@ class ChatCompletionRequest(BaseModel):
     def check_logprobs(cls, data):
 
         if (top_logprobs := data.get("top_logprobs")) is not None:
-            if top_logprobs < 0:
-                raise ValueError("`top_logprobs` must be a positive value.")
+            if top_logprobs < -1:
+                raise ValueError("`top_logprobs` must be a greater than -1.")
 
-            if top_logprobs > 0 and not data.get("logprobs"):
+            if not data.get("logprobs"):
                 raise ValueError("when using `top_logprobs`, `logprobs` must be set to true.")
 
+        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
+            if prompt_logprobs < -1:
+                raise ValueError("`prompt_logprobs` must be a greater than -1.")
         return data
 
 
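The net effect of the chat-side changes: top_logprobs = -1 now passes the bound check (previously any negative value was rejected), any explicit top_logprobs requires logprobs to be truthy, and prompt_logprobs gets the same >= -1 bound as in CompletionRequest. Restated as a plain function, a sketch of the new behavior:

    def check_chat_logprobs(data: dict) -> dict:
        # Sketched from the hunk above; not a verbatim copy of the committed method.
        if (top_logprobs := data.get("top_logprobs")) is not None:
            if top_logprobs < -1:
                raise ValueError("`top_logprobs` must be greater than -1.")
            if not data.get("logprobs"):
                raise ValueError("when using `top_logprobs`, `logprobs` must be set to true.")
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if prompt_logprobs < -1:
                raise ValueError("`prompt_logprobs` must be greater than -1.")
        return data


    check_chat_logprobs({"logprobs": True, "top_logprobs": -1})  # now accepted
    # check_chat_logprobs({"top_logprobs": 3})  would raise: logprobs must be set to true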