[Speculative Decoding] Add draft_logprobs Support for Speculative Decode MTP (#4467)

* feat: add draft_logprobs for Speculative Decode MTP

* feat: add draft_logprobs for Speculative Decode MTP

* feat: add draft_logprobs for Speculative Decode MTP

* fix: postprocess for speculative decode

* test: test_speculative_decoding_use_logprobs

* fix: test_completion_echo

* fix test_max_streaming_tokens

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
SunLei
2025-10-21 14:57:50 +08:00
committed by GitHub
parent 775edcc09a
commit ee915220bd
7 changed files with 422 additions and 48 deletions

View File

@@ -308,6 +308,7 @@ class CompletionOutput:
decode_type: int = 0
logprob: Optional[float] = None
top_logprobs: Optional[LogprobsLists] = None
draft_top_logprobs: Optional[LogprobsLists] = None
logprobs: Optional[SampleLogprobs] = None
draft_token_ids: list[int] = None
text: Optional[str] = None
@@ -322,9 +323,9 @@ class CompletionOutput:
"index": self.index,
"send_idx": self.send_idx,
"token_ids": self.token_ids,
"decode_type": self.decode_type,
"logprob": self.logprob,
"top_logprobs": self.top_logprobs,
"draft_top_logprobs": self.draft_top_logprobs,
"logprobs": self.logprobs,
"draft_token_ids": self.draft_token_ids,
"text": self.text,
@@ -350,6 +351,8 @@ class CompletionOutput:
f"draft_token_ids={self.draft_token_ids}, "
f"reasoning_content={self.reasoning_content!r}, "
f"logprobs={self.logprobs}, "
f"top_logprobs={self.top_logprobs}, "
f"draft_top_logprobs={self.draft_top_logprobs}, "
)
@@ -434,6 +437,7 @@ class RequestOutput:
request_id: str,
prompt: Optional[str] = None,
prompt_token_ids: Optional[list[int]] = None,
output_type: Optional[int] = 3,
outputs: CompletionOutput = None,
finished: bool = False,
metrics: Optional[RequestMetrics] = None,
@@ -444,6 +448,7 @@ class RequestOutput:
self.request_id = request_id
self.prompt = prompt
self.prompt_token_ids = prompt_token_ids
self.output_type = output_type
self.outputs = outputs
self.finished = finished
self.metrics = metrics
@@ -472,12 +477,21 @@ class RequestOutput:
self.outputs.top_logprobs.logprob_token_ids.extend(next_output.outputs.top_logprobs.logprob_token_ids)
self.outputs.top_logprobs.logprobs.extend(next_output.outputs.top_logprobs.logprobs)
self.outputs.top_logprobs.sampled_token_ranks.extend(next_output.outputs.top_logprobs.sampled_token_ranks)
if next_output.outputs.draft_top_logprobs is not None:
self.outputs.draft_top_logprobs.logprob_token_ids.extend(
next_output.outputs.draft_top_logprobs.logprob_token_ids
)
self.outputs.draft_top_logprobs.logprobs.extend(next_output.outputs.draft_top_logprobs.logprobs)
self.outputs.draft_top_logprobs.sampled_token_ranks.extend(
next_output.outputs.draft_top_logprobs.sampled_token_ranks
)
def __repr__(self) -> str:
return (
f"RequestOutput(request_id={self.request_id}, "
f"prompt={self.prompt!r}, "
f"prompt_token_ids={self.prompt_token_ids}, "
f"output_type={self.output_type}, "
f"outputs={self.outputs}, "
f"finished={self.finished}, "
f"num_cached_tokens={self.num_cached_tokens}, "
@@ -498,6 +512,7 @@ class RequestOutput:
"request_id": self.request_id,
"prompt": self.prompt,
"prompt_token_ids": self.prompt_token_ids,
"output_type": self.output_type,
"outputs": None if self.outputs is None else self.outputs.to_dict(),
"metrics": None if self.metrics is None else self.metrics.to_dict(),
"finished": self.finished,