[Speculative Decoding] Add draft_logprobs Support for Speculative Decode MTP (#4467)

* feat: add draft_logprobs for Speculative Decode MTP

* feat: add draft_logprobs for Speculative Decode MTP

* feat: add draft_logprobs for Speculative Decode MTP

* fix: postprocess for speculative decode

* test: test_speculative_decoding_use_logprobs

* fix: test_completion_echo

* fix test_max_streaming_tokens

---------

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
Author: SunLei
Date: 2025-10-21 14:57:50 +08:00
Committed by: GitHub
Parent: 775edcc09a
Commit: ee915220bd
7 changed files with 422 additions and 48 deletions

@@ -22,6 +22,7 @@ import traceback
import weakref
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from typing import List
import numpy as np
import paddle
@@ -67,11 +68,20 @@ class TokenProcessor:
self.use_logprobs = self.cfg.model_config.enable_logprob
if self.speculative_decoding:
-            self.output_tokens = paddle.full(
-                shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2],
-                fill_value=2,
-                dtype="int64",
-            )
+            if self.use_logprobs:
+                self.output_tokens = paddle.full(
+                    shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1) + MAX_BSZ + 3, 1], fill_value=2, dtype="int64"
+                )
+                self.output_scores = paddle.full(
+                    shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1), 1], fill_value=0.0, dtype="float32"
+                )
+                self.output_ranks = paddle.full(shape=[MAX_BSZ * MAX_DRAFT_TOKENS], fill_value=0, dtype="int64")
+            else:
+                self.output_tokens = paddle.full(
+                    shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2],
+                    fill_value=2,
+                    dtype="int64",
+                )
elif self.use_logprobs:
self.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64")
self.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32")
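
The buffer shapes above encode a flat layout: when speculative decoding and logprobs are both enabled, a single int64 tensor carries a short header plus, for every request and draft position, the sampled token and its K alternatives. A minimal sketch of those sizes, using placeholder constants (the real MAX_BSZ, MAX_DRAFT_TOKENS and K come from FastDeploy's own definitions):

# Placeholder constants for illustration only.
MAX_BSZ = 512          # assumed maximum batch size
MAX_DRAFT_TOKENS = 6   # assumed maximum draft tokens per step
K = 20                 # assumed number of top-K alternatives kept per sampled token

def spec_logprob_buffer_sizes(max_bsz=MAX_BSZ, max_draft=MAX_DRAFT_TOKENS, k=K):
    """Element counts for the three buffers used on the spec-decode logprobs path."""
    # output_tokens: 3 header slots (ready flag, message type, batch size),
    # then MAX_BSZ accept counts, then one (sampled token + K alternatives)
    # row per (request, draft position).
    tokens = max_bsz * max_draft * (k + 1) + max_bsz + 3
    # output_scores: one logprob per stored candidate token id.
    scores = max_bsz * max_draft * (k + 1)
    # output_ranks: the sampled token's rank for every (request, draft position).
    ranks = max_bsz * max_draft
    return tokens, scores, ranks

print(spec_logprob_buffer_sizes())  # (65027, 64512, 3072) with the values above
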
@@ -107,6 +117,7 @@ class TokenProcessor:
self.executor = ThreadPoolExecutor(max_workers=1)
self.prefill_result_status = dict()
self._finalizer = weakref.finalize(self, self._cleanup_resources)
self._batch_result_buffer = None
def _cleanup_resources(self):
"""Cleaning up shared memory resources"""
@@ -312,6 +323,7 @@ class TokenProcessor:
get_output_ep,
get_output_topk,
speculate_get_output,
speculate_get_output_topk,
)
rank_id = self.cfg.parallel_config.local_data_parallel_id
@@ -319,15 +331,27 @@ class TokenProcessor:
try:
is_blocking = True
if self.speculative_decoding:
-                    if (
-                        self.cfg.parallel_config.enable_expert_parallel
-                        and self.cfg.parallel_config.data_parallel_size > 1
-                    ):
-                        speculate_get_output(self.output_tokens, rank_id, is_blocking, True)
+                    if self.use_logprobs:
+                        speculate_get_output_topk(
+                            self.output_tokens,
+                            self.output_scores,
+                            self.output_ranks,
+                            K,
+                            rank_id,
+                            is_blocking,
+                        )
+                        if self.output_tokens[0, 0] == -2:
+                            continue
                    else:
-                        speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
-                    if self.output_tokens[0] == -2:
-                        continue
+                        if (
+                            self.cfg.parallel_config.enable_expert_parallel
+                            and self.cfg.parallel_config.data_parallel_size > 1
+                        ):
+                            speculate_get_output(self.output_tokens, rank_id, is_blocking, True)
+                        else:
+                            speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
+                        if self.output_tokens[0] == -2:
+                            continue
else:
if self.use_logprobs:
get_output_topk(
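
In this reader loop, the speculative + logprobs branch pulls one step's worth of data into the three buffers, and a value of -2 at output_tokens[0, 0] signals that nothing has been produced yet, so the thread simply polls again. A compact sketch of that loop shape, with a hypothetical fetch_step_output callable standing in for the speculate_get_output_topk op:

NOT_READY = -2  # flag in output_tokens[0, 0] meaning "no step output yet"

def poll_step_outputs(fetch_step_output, handle_step):
    """fetch_step_output is a stand-in for the blocking shared-memory reader."""
    while True:
        tokens, scores, ranks = fetch_step_output()  # fills the three output buffers
        if tokens[0, 0] == NOT_READY:
            continue                                 # nothing produced this round
        handle_step(tokens, scores, ranks)           # parse and post-process the batch
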
@@ -372,7 +396,7 @@ class TokenProcessor:
self.executor.submit(process_metrics)
-    def postprocess(self, batch_result):
+    def postprocess(self, batch_result: List[RequestOutput], mtype=3):
"""
single post-processing function
@@ -380,7 +404,28 @@ class TokenProcessor:
batch_result (list): batch results
"""
try:
-            self.cached_generated_tokens.put_results(batch_result)
+            if self.cfg.speculative_config.method and self.use_logprobs:
+                if mtype == 3:  # target
+                    finished_batch_result, unfinished_batch_result = [], []
+                    for r in batch_result:
+                        (finished_batch_result if r.finished else unfinished_batch_result).append(r)
+                    if finished_batch_result:
+                        self.cached_generated_tokens.put_results(batch_result)
+                    else:
+                        self._batch_result_buffer = unfinished_batch_result
+                elif mtype == 4:  # draft
+                    target_batch_result = []
+                    draft_batch_result = batch_result
+                    if self._batch_result_buffer is not None:
+                        for target, decode in zip(self._batch_result_buffer, draft_batch_result):
+                            target.outputs.draft_top_logprobs = decode.outputs.draft_top_logprobs
+                            target_batch_result.append(target)
+                        self._batch_result_buffer = None
+                        self.cached_generated_tokens.put_results(target_batch_result)
+                    else:
+                        self.cached_generated_tokens.put_results(batch_result)
+            else:
+                self.cached_generated_tokens.put_results(batch_result)
except Exception as e:
llm_logger.error(f"Error in TokenProcessor's postprocess: {e}, {str(traceback.format_exc())}")
@@ -471,9 +516,25 @@ class TokenProcessor:
tokens = self.output_tokens.numpy()
scores = None
ranks = None
+        # target:3, draft:4
+        mtype = 3
        if self.cfg.speculative_config.method:
-            batch = self.output_tokens[1]
-            accept_num = tokens[2 : batch + 2]
+            if self.use_logprobs:
+                mtype = int(self.output_tokens[1, 0].item())
+                batch = self.output_tokens[2, 0]
+                accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]]
+                tokens = tokens[3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1)].reshape(
+                    [batch, MAX_DRAFT_TOKENS, K + 1]
+                )
+                scores = (
+                    self.output_scores[: batch * MAX_DRAFT_TOKENS * (K + 1)]
+                    .numpy()
+                    .reshape([batch, MAX_DRAFT_TOKENS, K + 1])
+                )
+                ranks = self.output_ranks[: batch * MAX_DRAFT_TOKENS].numpy().reshape([batch, MAX_DRAFT_TOKENS])
+            else:
+                batch = self.output_tokens[1]
+                accept_num = tokens[2 : batch + 2]
self._record_speculative_decoding_mertics(accept_num)
elif self.use_logprobs:
batch = self.output_tokens[1, 0]
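
On the read side, the flat buffers are re-interpreted per step: the header gives the message type and batch size, the next MAX_BSZ slots give per-request accept counts, and the remainder reshapes to [batch, MAX_DRAFT_TOKENS, K + 1], where column 0 of the last axis holds the sampled token at each draft position. A numpy sketch of that parsing with small illustrative constants (the real values come from FastDeploy):

import numpy as np

MAX_BSZ, MAX_DRAFT_TOKENS, K = 4, 3, 2  # assumed small values for the example

def parse_spec_logprob_tokens(output_tokens: np.ndarray):
    """output_tokens: int64 column vector laid out as described in the diff above."""
    mtype = int(output_tokens[1, 0])   # 3 = target step, 4 = draft step
    batch = int(output_tokens[2, 0])
    accept_num = [int(v) for v in output_tokens[3 : batch + 3, 0]]
    body = output_tokens[3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1), 0]
    grid = body.reshape(batch, MAX_DRAFT_TOKENS, K + 1)  # [request, draft position, 1 + K]
    # Column 0 is the sampled token; only the first accept_num[i] positions
    # were actually accepted for request i.
    accepted = [grid[i, : accept_num[i], 0].tolist() for i in range(batch)]
    return mtype, accept_num, accepted

# Example: one request that accepted 2 draft positions.
buf = np.zeros((MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1) + MAX_BSZ + 3, 1), dtype=np.int64)
buf[1, 0], buf[2, 0], buf[3, 0] = 3, 1, 2  # target message, batch = 1, accept_num = [2]
print(parse_spec_logprob_tokens(buf))      # -> (3, [2], [[0, 0]])
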
@@ -501,6 +562,8 @@ class TokenProcessor:
if recovery_stop:
llm_logger.info(f"recovery stop signal found at task {task_id}")
token_ids = [RECOVERY_STOP_SIGNAL]
elif self.use_logprobs:
token_ids = tokens[i][:, 0].tolist()[: accept_num[i]]
else:
token_ids = tokens[
2
@@ -556,6 +619,7 @@ class TokenProcessor:
self._record_metrics(task, current_time, token_ids)
result = RequestOutput(
request_id=task_id,
output_type=mtype,
outputs=CompletionOutput(
index=i,
send_idx=self.tokens_counter[task_id],
@@ -575,29 +639,54 @@ class TokenProcessor:
if is_prefill and len(token_ids) > 1:
result.outputs.draft_token_ids = copy.deepcopy(token_ids)
-            for token_id in token_ids:
+            for batch_token_index in range(len(token_ids)):
+                token_id = token_ids[batch_token_index]
self.tokens_counter[task_id] += 1
if token_id != RECOVERY_STOP_SIGNAL:
if not (envs.FD_ENABLE_INTERNAL_ADAPTER and token_id in task.eos_token_ids):
result.outputs.token_ids.append(token_id)
task.output_token_ids.append(token_id)
if self.use_logprobs:
-                    result.outputs.logprob = float(scores[i, 0])
-                    # Construct top_logprobs
-                    topk_token_ids = tokens[i, :].tolist()
-                    topk_logprobs = scores[i, :].tolist()
-                    sampled_rank = ranks[i].item()
-                    result.outputs.top_logprobs = LogprobsLists(
-                        logprob_token_ids=[topk_token_ids],
-                        logprobs=[topk_logprobs],
-                        sampled_token_ranks=[sampled_rank],
-                    )
-                if token_id in task.eos_token_ids or is_prefill or recovery_stop:
+                    if self.cfg.speculative_config.method:
+                        result.outputs.logprob = float(scores[i, batch_token_index, 0])
+                        topk_token_ids = tokens[i, batch_token_index, :].tolist()
+                        topk_logprobs = scores[i, batch_token_index, :].tolist()
+                        sampled_rank = ranks[i, batch_token_index].item()
+                    else:
+                        result.outputs.logprob = float(scores[i, 0])
+                        topk_token_ids = tokens[i, :].tolist()
+                        topk_logprobs = scores[i, :].tolist()
+                        sampled_rank = ranks[i].item()
+                    if mtype == 3:  # top_logprobs
+                        if result.outputs.top_logprobs is None:
+                            result.outputs.top_logprobs = LogprobsLists(
+                                logprob_token_ids=[topk_token_ids],
+                                logprobs=[topk_logprobs],
+                                sampled_token_ranks=[sampled_rank],
+                            )
+                        else:
+                            result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids])
+                            result.outputs.top_logprobs.logprobs.extend([topk_logprobs])
+                            result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank])
+                    elif mtype == 4:  # draft_top_logprobs
+                        if result.outputs.draft_top_logprobs is None:
+                            result.outputs.draft_top_logprobs = LogprobsLists(
+                                logprob_token_ids=[topk_token_ids],
+                                logprobs=[topk_logprobs],
+                                sampled_token_ranks=[sampled_rank],
+                            )
+                        else:
+                            result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids])
+                            result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs])
+                            result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank])
+                if mtype == 3 and (token_id in task.eos_token_ids or is_prefill or recovery_stop):
result.finished = True
if recovery_stop:
result.error_msg = "Recover is not supported, the result is incomplete!"
llm_logger.info(
f"Request: {task_id} finished, number of " f"generated tokens: {self.tokens_counter[task_id]}."
f"Request: {task_id} finished, number of "
f"generated tokens: {self.tokens_counter[task_id]}, token_id:{token_id},is_prefill:{is_prefill},recovery_stop:{recovery_stop}"
)
llm_logger.info(
f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}"
@@ -616,7 +705,7 @@ class TokenProcessor:
):
batch_result.append(result)
-        self.postprocess(batch_result)
+        self.postprocess(batch_result, mtype)
def _record_metrics(self, task, current_time, token_ids):
"""Record all metrics for a task"""