[Speculative Decoding] Add draft_logprobs Support for Speculative Decode MTP (#4467)

* feat: add draft_logprobs for Speculative Decode MTP

* feat: add draft_logprobs for Speculative Decode MTP

* feat: add draft_logprobs for Speculative Decode MTP

* fix: postprocess for speculative decode

* test: test_speculative_decoding_use_logprobs

* fix: test_completion_echo

* fix test_max_streaming_tokens

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
SunLei
2025-10-21 14:57:50 +08:00
committed by GitHub
parent 775edcc09a
commit ee915220bd
7 changed files with 422 additions and 48 deletions

View File

@@ -205,6 +205,7 @@ class ChatCompletionResponseChoice(BaseModel):
index: int
message: ChatMessage
logprobs: Optional[LogProbs] = None
draft_logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls", "recover_stop"]]
@@ -265,6 +266,7 @@ class ChatCompletionResponseStreamChoice(BaseModel):
index: int
delta: DeltaMessage
logprobs: Optional[LogProbs] = None
draft_logprobs: Optional[LogProbs] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls"]] = None
arrival_time: Optional[float] = None
@@ -295,6 +297,7 @@ class CompletionResponseChoice(BaseModel):
completion_tokens: Optional[str] = None
arrival_time: Optional[float] = None
logprobs: Optional[CompletionLogprobs] = None
draft_logprobs: Optional[CompletionLogprobs] = None
reasoning_content: Optional[str] = None
finish_reason: Optional[Literal["stop", "length", "tool_calls"]]
tool_calls: Optional[List[DeltaToolCall | ToolCall]] = None
@@ -333,6 +336,7 @@ class CompletionResponseStreamChoice(BaseModel):
text: str
arrival_time: float = None
logprobs: Optional[CompletionLogprobs] = None
draft_logprobs: Optional[CompletionLogprobs] = None
prompt_token_ids: Optional[List[int]] = None
completion_token_ids: Optional[List[int]] = None
prompt_tokens: Optional[str] = None
@@ -420,6 +424,7 @@ class CompletionRequest(BaseModel):
echo: Optional[bool] = False
frequency_penalty: Optional[float] = Field(default=None, ge=-2, le=2)
logprobs: Optional[int] = None
include_draft_logprobs: Optional[bool] = False
# For logits and logprobs post processing
temp_scaled_logprobs: bool = False
top_p_normalized_logprobs: bool = False
@@ -555,6 +560,7 @@ class ChatCompletionRequest(BaseModel):
frequency_penalty: Optional[float] = Field(None, le=2, ge=-2)
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = 0
include_draft_logprobs: Optional[bool] = False
# For logits and logprobs post processing
temp_scaled_logprobs: bool = False