feat: add draft_logprobs for Speculative Decode MTP

Author: sunlei1024
Authored: 2025-09-26 13:07:48 +08:00
Committed by: Deleter-D
Parent: aed79aec4f
Commit: d5a3c5c933
6 changed files with 293 additions and 24 deletions


@@ -109,6 +109,7 @@ class TokenProcessor:
self.executor = ThreadPoolExecutor(max_workers=1)
self.prefill_result_status = dict()
self._finalizer = weakref.finalize(self, self._cleanup_resources)
self._batch_result_buffer = None
def _cleanup_resources(self):
"""Cleaning up shared memory resources"""
@@ -165,7 +166,20 @@ class TokenProcessor:
try:
is_blocking = True
if self.speculative_decoding:
speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
if (
self.cfg.parallel_config.enable_expert_parallel
and self.cfg.parallel_config.data_parallel_size > 1
):
if self.use_logprobs:
# TODO speculate_get_output_with_topk
pass
else:
speculate_get_output(self.output_tokens, rank_id, is_blocking, True)
elif self.use_logprobs:
# TODO speculate_get_output_with_topk
pass
else:
speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
if self.output_tokens[0] == -2:
continue
@@ -213,7 +227,7 @@ class TokenProcessor:
self.executor.submit(process_metrics)
def postprocess(self, batch_result):
def postprocess(self, batch_result, mtype=3):
"""
single post-processing function
@@ -221,7 +235,21 @@ class TokenProcessor:
batch_result (list): batch results
"""
try:
self.cached_generated_tokens.put_results(batch_result)
if self.cfg.speculative_config.method and self.cfg.use_logprobs:
if mtype == 3: # target
self._batch_result_buffer = batch_result
elif mtype == 4: # draft
target_batch_result = []
draft_batch_result = batch_result
for target, decode in zip(self._batch_result_buffer, draft_batch_result):
target["outputs"]["draft_top_logprobs"] = decode["outputs"]["draft_top_logprobs"]
target_batch_result.append(target)
self._batch_result_buffer = None
self.cached_generated_tokens.put_results(target_batch_result)
else:
self.cached_generated_tokens.put_results(batch_result)
else:
self.cached_generated_tokens.put_results(batch_result)
except Exception as e:
llm_logger.error(f"Error in TokenProcessor's postprocess: {e}, {str(traceback.format_exc())}")
@@ -302,9 +330,19 @@ class TokenProcessor:
tokens = self.output_tokens.numpy()
scores = None
ranks = None
# target:3, draft:4
mtype = 3
if self.cfg.speculative_config.method:
batch = self.output_tokens[1]
accept_num = tokens[2 : batch + 2]
if self.use_logprobs:
mtype = self.output_tokens[1, 0]
batch = self.output_tokens[2, 0]
accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]]
tokens = tokens[3 + batch : 3 + batch + batch * (K + 1) * MAX_DRAFT_TOKENS].reshape(
[batch, K + 1, MAX_DRAFT_TOKENS]
)
else:
batch = self.output_tokens[1]
accept_num = tokens[2 : batch + 2]
self._record_speculative_decoding_mertics(accept_num)
elif self.use_logprobs:
batch = self.output_tokens[1, 0]
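
With use_logprobs enabled, the speculative path reads a small header from the shared output buffer before the token payload: row 1 carries the message type (3 = target, 4 = draft), row 2 the batch size, rows 3..3+batch the per-request accept counts, and the remainder the candidate tokens, reshaped to [batch, K + 1, MAX_DRAFT_TOKENS]. Below is a rough stand-alone decoder for that layout; the K and MAX_DRAFT_TOKENS values are made up for the example, the real ones come from the speculative decoding configuration:

import numpy as np

# Illustrative constants; the real K and MAX_DRAFT_TOKENS come from the
# speculative decoding configuration, not from this sketch.
K = 3
MAX_DRAFT_TOKENS = 4

def decode_speculative_logprob_buffer(output_tokens: np.ndarray):
    """Unpack the header/payload layout used by the use_logprobs branch above.

    output_tokens is treated as a column vector: [status flag, mtype, batch,
    accept_num[0..batch-1], token payload...], mirroring the indices in the diff.
    """
    mtype = int(output_tokens[1, 0])      # 3 = target pass, 4 = draft pass
    batch = int(output_tokens[2, 0])
    accept_num = [int(n[0]) for n in output_tokens[3 : batch + 3]]
    payload = output_tokens[3 + batch : 3 + batch + batch * (K + 1) * MAX_DRAFT_TOKENS]
    tokens = payload.reshape([batch, K + 1, MAX_DRAFT_TOKENS])
    return mtype, batch, accept_num, tokens

# Tiny synthetic buffer: batch = 1, one accepted token, target pass.
buf = np.zeros((3 + 1 + (K + 1) * MAX_DRAFT_TOKENS, 1), dtype=np.int64)
buf[1, 0], buf[2, 0], buf[3, 0] = 3, 1, 1
mtype, batch, accept_num, tokens = decode_speculative_logprob_buffer(buf)
print(mtype, batch, accept_num, tokens.shape)  # 3 1 [1] (1, 4, 4)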
@@ -332,19 +370,24 @@ class TokenProcessor:
task_id = task.request_id
if self.cfg.speculative_config.method:
token_ids = tokens[
2
+ SPECULATE_MAX_BSZ
+ i * MAX_DRAFT_TOKENS : 2
+ SPECULATE_MAX_BSZ
+ i * MAX_DRAFT_TOKENS
+ accept_num[i]
].tolist()
if len(token_ids) == 0 or token_ids[-1] <= 0:
if envs.ENABLE_V1_KVCACHE_SCHEDULER:
if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
self.resource_manager.reschedule_preempt_task(task_id)
continue
if accept_num[i] == -3:
recovery_stop = True
if recovery_stop:
llm_logger.info(f"recovery stop signal found at task {task_id}")
token_ids = [RECOVERY_STOP_SIGNAL]
elif self.use_logprobs:
token_ids = tokens[i][:, 0].tolist()[: accept_num[i]]
else:
token_ids = tokens[
2
+ SPECULATE_MAX_BSZ
+ i * MAX_DRAFT_TOKENS : 2
+ SPECULATE_MAX_BSZ
+ i * MAX_DRAFT_TOKENS
+ accept_num[i]
].tolist()
if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
continue
else:
token_id = int(tokens[i, 0])
token_ids = [token_id]
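
In the use_logprobs branch, the accepted token ids for request i are read from slot 0 on the trailing axis of tokens[i] and truncated to accept_num[i]; an accept count of -3 marks a recovery stop and yields the stop sentinel instead of real tokens. A short continuation of the decoder sketch above (the RECOVERY_STOP_SIGNAL value here is a placeholder; the real constant is defined elsewhere in the module):

# Continuation of the decoder sketch: accepted token ids for one request.
RECOVERY_STOP_SIGNAL = -100  # placeholder value; the real constant is imported in the module

def accepted_token_ids(tokens, accept_num, i):
    """Return (token_ids, recovery_stop) for request i, use_logprobs branch."""
    if accept_num[i] == -3:
        # Recovery stop: emit the stop sentinel instead of real tokens.
        return [RECOVERY_STOP_SIGNAL], True
    # tokens[i] has shape [K + 1, MAX_DRAFT_TOKENS]; slot 0 on the trailing
    # axis is read as the sampled token, truncated to the accepted count.
    return tokens[i][:, 0].tolist()[: accept_num[i]], False

print(accepted_token_ids(tokens, accept_num, 0))  # ([0], False) for the synthetic buffer above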
@@ -387,6 +430,7 @@ class TokenProcessor:
self._record_metrics(task, current_time, token_ids)
result = RequestOutput(
request_id=task_id,
output_type=mtype,
outputs=CompletionOutput(
index=i,
send_idx=self.tokens_counter[task_id],
@@ -412,16 +456,36 @@ class TokenProcessor:
result.outputs.token_ids.append(token_id)
task.output_token_ids.append(token_id)
if self.use_logprobs:
# TODO: compatibility support for the speculative decoding scenario
result.outputs.logprob = float(scores[i, 0])
# Construct top_logprobs
topk_token_ids = tokens[i, :].tolist()
topk_logprobs = scores[i, :].tolist()
sampled_rank = ranks[i].item()
result.outputs.top_logprobs = LogprobsLists(
logprob_token_ids=[topk_token_ids],
logprobs=[topk_logprobs],
sampled_token_ranks=[sampled_rank],
)
if mtype == 3: # top_logprobs
if result.outputs.top_logprobs is None:
result.outputs.top_logprobs = LogprobsLists(
logprob_token_ids=[topk_token_ids],
logprobs=[topk_logprobs],
sampled_token_ranks=[sampled_rank],
)
else:
result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids])
result.outputs.top_logprobs.logprobs.extend([topk_logprobs])
result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank])
elif mtype == 4: # draft_top_logprobs
if result.outputs.draft_top_logprobs is None:
result.outputs.draft_top_logprobs = LogprobsLists(
logprob_token_ids=[topk_token_ids],
logprobs=[topk_logprobs],
sampled_token_ranks=[sampled_rank],
)
else:
result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids])
result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs])
result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank])
if token_id in task.eos_token_ids or is_prefill or recovery_stop:
result.finished = True
if recovery_stop:
@@ -442,7 +506,7 @@ class TokenProcessor:
if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
batch_result.append(result)
self.postprocess(batch_result)
self.postprocess(batch_result, mtype)
def _record_metrics(self, task, current_time, token_ids):
"""Record all metrics for a task"""