mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Speculative Decoding] Add draft_logprobs Support for Speculative Decode MTP (#4467)
* feat: add draft_logprobs for Speculative Decode MTP * feat: add draft_logprobs for Speculative Decode MTP * feat: add draft_logprobs for Speculative Decode MTP * fix: postprocess for speculative decode * test: test_speculative_decoding_use_logprobs * fix: test_completion_echo * fix test_max_streaming_tokens --------- Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
This commit is contained in:
@@ -308,6 +308,7 @@ class CompletionOutput:
|
||||
decode_type: int = 0
|
||||
logprob: Optional[float] = None
|
||||
top_logprobs: Optional[LogprobsLists] = None
|
||||
draft_top_logprobs: Optional[LogprobsLists] = None
|
||||
logprobs: Optional[SampleLogprobs] = None
|
||||
draft_token_ids: list[int] = None
|
||||
text: Optional[str] = None
|
||||
@@ -322,9 +323,9 @@ class CompletionOutput:
|
||||
"index": self.index,
|
||||
"send_idx": self.send_idx,
|
||||
"token_ids": self.token_ids,
|
||||
"decode_type": self.decode_type,
|
||||
"logprob": self.logprob,
|
||||
"top_logprobs": self.top_logprobs,
|
||||
"draft_top_logprobs": self.draft_top_logprobs,
|
||||
"logprobs": self.logprobs,
|
||||
"draft_token_ids": self.draft_token_ids,
|
||||
"text": self.text,
|
||||
@@ -350,6 +351,8 @@ class CompletionOutput:
|
||||
f"draft_token_ids={self.draft_token_ids}, "
|
||||
f"reasoning_content={self.reasoning_content!r}, "
|
||||
f"logprobs={self.logprobs}, "
|
||||
f"top_logprobs={self.top_logprobs}, "
|
||||
f"draft_top_logprobs={self.draft_top_logprobs}, "
|
||||
)
|
||||
|
||||
|
||||
@@ -434,6 +437,7 @@ class RequestOutput:
|
||||
request_id: str,
|
||||
prompt: Optional[str] = None,
|
||||
prompt_token_ids: Optional[list[int]] = None,
|
||||
output_type: Optional[int] = 3,
|
||||
outputs: CompletionOutput = None,
|
||||
finished: bool = False,
|
||||
metrics: Optional[RequestMetrics] = None,
|
||||
@@ -444,6 +448,7 @@ class RequestOutput:
|
||||
self.request_id = request_id
|
||||
self.prompt = prompt
|
||||
self.prompt_token_ids = prompt_token_ids
|
||||
self.output_type = output_type
|
||||
self.outputs = outputs
|
||||
self.finished = finished
|
||||
self.metrics = metrics
|
||||
@@ -472,12 +477,21 @@ class RequestOutput:
|
||||
self.outputs.top_logprobs.logprob_token_ids.extend(next_output.outputs.top_logprobs.logprob_token_ids)
|
||||
self.outputs.top_logprobs.logprobs.extend(next_output.outputs.top_logprobs.logprobs)
|
||||
self.outputs.top_logprobs.sampled_token_ranks.extend(next_output.outputs.top_logprobs.sampled_token_ranks)
|
||||
if next_output.outputs.draft_top_logprobs is not None:
|
||||
self.outputs.draft_top_logprobs.logprob_token_ids.extend(
|
||||
next_output.outputs.draft_top_logprobs.logprob_token_ids
|
||||
)
|
||||
self.outputs.draft_top_logprobs.logprobs.extend(next_output.outputs.draft_top_logprobs.logprobs)
|
||||
self.outputs.draft_top_logprobs.sampled_token_ranks.extend(
|
||||
next_output.outputs.draft_top_logprobs.sampled_token_ranks
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"RequestOutput(request_id={self.request_id}, "
|
||||
f"prompt={self.prompt!r}, "
|
||||
f"prompt_token_ids={self.prompt_token_ids}, "
|
||||
f"output_type={self.output_type}, "
|
||||
f"outputs={self.outputs}, "
|
||||
f"finished={self.finished}, "
|
||||
f"num_cached_tokens={self.num_cached_tokens}, "
|
||||
@@ -498,6 +512,7 @@ class RequestOutput:
|
||||
"request_id": self.request_id,
|
||||
"prompt": self.prompt,
|
||||
"prompt_token_ids": self.prompt_token_ids,
|
||||
"output_type": self.output_type,
|
||||
"outputs": None if self.outputs is None else self.outputs.to_dict(),
|
||||
"metrics": None if self.metrics is None else self.metrics.to_dict(),
|
||||
"finished": self.finished,
|
||||
|
||||
Reference in New Issue
Block a user