mirror of https://github.com/PaddlePaddle/FastDeploy.git
feat: add draft_logprobs for Speculative Decode MTP
@@ -109,6 +109,7 @@ class TokenProcessor:
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.prefill_result_status = dict()
        self._finalizer = weakref.finalize(self, self._cleanup_resources)
        self._batch_result_buffer = None

    def _cleanup_resources(self):
        """Cleaning up shared memory resources"""
@@ -165,7 +166,20 @@ class TokenProcessor:
            try:
                is_blocking = True
                if self.speculative_decoding:
                    speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
                    if (
                        self.cfg.parallel_config.enable_expert_parallel
                        and self.cfg.parallel_config.data_parallel_size > 1
                    ):
                        if self.use_logprobs:
                            # TODO speculate_get_output_with_topk
                            pass
                        else:
                            speculate_get_output(self.output_tokens, rank_id, is_blocking, True)
                    elif self.use_logprobs:
                        # TODO speculate_get_output_with_topk
                        pass
                    else:
                        speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
                    if self.output_tokens[0] == -2:
                        continue
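For orientation: the reader loop above blocks on a shared output buffer and treats its first element as a status word, where -2 appears to mean that no new batch is ready and the iteration is skipped. A minimal, self-contained sketch of that polling pattern follows (the buffer layout, sentinel name, and helper functions are illustrative assumptions, not FastDeploy's actual shared-memory interface):

import numpy as np

NOT_READY = -2  # sentinel seen in the diff: output_tokens[0] == -2 means "skip this poll"

def drain_ready_batches(poll_fn, handle_fn, max_polls=100):
    """Poll a producer and hand ready buffers to a handler."""
    for _ in range(max_polls):
        buf = poll_fn()
        if buf[0] == NOT_READY:
            continue  # nothing was produced for this poll, mirroring the `continue` above
        handle_fn(buf)

# toy usage: every other poll returns "not ready"
state = {"n": 0}

def fake_poll():
    state["n"] += 1
    return np.array([NOT_READY]) if state["n"] % 2 else np.array([0, 42])

drain_ready_batches(fake_poll, lambda b: print("got batch:", b.tolist()), max_polls=4)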
@@ -213,7 +227,7 @@ class TokenProcessor:
            self.executor.submit(process_metrics)

    def postprocess(self, batch_result):
    def postprocess(self, batch_result, mtype=3):
        """
        single post-processing function
@@ -221,7 +235,21 @@ class TokenProcessor:
            batch_result (list): batch results
        """
        try:
            self.cached_generated_tokens.put_results(batch_result)
            if self.cfg.speculative_config.method and self.cfg.use_logprobs:
                if mtype == 3:  # target
                    self._batch_result_buffer = batch_result
                elif mtype == 4:  # draft
                    target_batch_result = []
                    draft_batch_result = batch_result
                    for target, decode in zip(self._batch_result_buffer, draft_batch_result):
                        target["outputs"]["draft_top_logprobs"] = decode["outputs"]["draft_top_logprobs"]
                        target_batch_result.append(target)
                    self._batch_result_buffer = None
                    self.cached_generated_tokens.put_results(target_batch_result)
                else:
                    self.cached_generated_tokens.put_results(batch_result)
            else:
                self.cached_generated_tokens.put_results(batch_result)
        except Exception as e:
            llm_logger.error(f"Error in TokenProcessor's postprocess: {e}, {str(traceback.format_exc())}")
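The buffering scheme above pairs each target-model batch (mtype == 3) with the draft-model batch (mtype == 4) that follows it: the target batch is parked in _batch_result_buffer, and when the draft batch arrives its draft_top_logprobs are copied onto the buffered target results before they are emitted. A standalone sketch of that pairing logic, using plain dicts in place of the real result objects:

class DraftLogprobMerger:
    """Pair target batches (mtype=3) with the following draft batch (mtype=4)."""

    TARGET, DRAFT = 3, 4

    def __init__(self, emit):
        self._buffer = None   # holds the last target batch awaiting its draft counterpart
        self._emit = emit     # callable that receives a finished batch

    def push(self, batch, mtype):
        if mtype == self.TARGET:
            self._buffer = batch                     # wait for the draft pass
        elif mtype == self.DRAFT and self._buffer is not None:
            merged = []
            for target, draft in zip(self._buffer, batch):
                target["outputs"]["draft_top_logprobs"] = draft["outputs"]["draft_top_logprobs"]
                merged.append(target)
            self._buffer = None
            self._emit(merged)
        else:
            self._emit(batch)                        # logprobs disabled or unexpected type

# toy usage
merger = DraftLogprobMerger(emit=lambda b: print("emitted:", b))
merger.push([{"outputs": {"token_ids": [1], "draft_top_logprobs": None}}], mtype=3)
merger.push([{"outputs": {"draft_top_logprobs": [[0.1, 0.2]]}}], mtype=4)

A single-slot buffer like this relies on the producer emitting the two message types in strict target-then-draft order; if a draft batch could ever arrive first, the merge step would need an extra guard.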
@@ -302,9 +330,19 @@ class TokenProcessor:
        tokens = self.output_tokens.numpy()
        scores = None
        ranks = None
        # target:3, draft:4
        mtype = 3
        if self.cfg.speculative_config.method:
            batch = self.output_tokens[1]
            accept_num = tokens[2 : batch + 2]
            if self.use_logprobs:
                mtype = self.output_tokens[1, 0]
                batch = self.output_tokens[2, 0]
                accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]]
                tokens = tokens[3 + batch : 3 + batch + batch * (K + 1) * MAX_DRAFT_TOKENS].reshape(
                    [batch, K + 1, MAX_DRAFT_TOKENS]
                )
            else:
                batch = self.output_tokens[1]
                accept_num = tokens[2 : batch + 2]
            self._record_speculative_decoding_mertics(accept_num)
        elif self.use_logprobs:
            batch = self.output_tokens[1, 0]
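When logprobs are enabled, the speculative output buffer is read as a small header followed by a token block: index 1 carries the message type, index 2 the batch size, indices 3..3+batch the per-request accept counts, and the next batch * (K + 1) * MAX_DRAFT_TOKENS entries are reshaped to [batch, K + 1, MAX_DRAFT_TOKENS]. A small numpy sketch that packs and then unpacks a flat buffer with that layout (the tiny K, MAX_DRAFT_TOKENS, and token values are made up for illustration, and the real buffer is a 2-D tensor rather than this simplified flat array):

import numpy as np

K, MAX_DRAFT_TOKENS = 2, 3          # illustrative sizes, not FastDeploy's real constants
batch = 2
accept_num = [2, 1]

# fabricate a token block: one [K + 1, MAX_DRAFT_TOKENS] slab per request
token_block = np.arange(batch * (K + 1) * MAX_DRAFT_TOKENS)

# header layout read by the diff: [status, mtype, batch, accept_num..., tokens...]
buf = np.concatenate([[0, 3, batch], accept_num, token_block])

mtype = buf[1]                                   # 3 = target, 4 = draft
bsz = buf[2]
accepts = [int(n) for n in buf[3 : bsz + 3]]
tokens = buf[3 + bsz : 3 + bsz + bsz * (K + 1) * MAX_DRAFT_TOKENS].reshape(
    [bsz, K + 1, MAX_DRAFT_TOKENS]
)
print(mtype, accepts, tokens.shape)              # 3 [2, 1] (2, 3, 3)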
@@ -332,19 +370,24 @@ class TokenProcessor:
            task_id = task.request_id
            if self.cfg.speculative_config.method:
                token_ids = tokens[
                    2
                    + SPECULATE_MAX_BSZ
                    + i * MAX_DRAFT_TOKENS : 2
                    + SPECULATE_MAX_BSZ
                    + i * MAX_DRAFT_TOKENS
                    + accept_num[i]
                ].tolist()
                if len(token_ids) == 0 or token_ids[-1] <= 0:
                    if envs.ENABLE_V1_KVCACHE_SCHEDULER:
                        if task_id in self.resource_manager.to_be_rescheduled_request_id_set:
                            self.resource_manager.reschedule_preempt_task(task_id)
                    continue
                if accept_num[i] == -3:
                    recovery_stop = True
                    if recovery_stop:
                        llm_logger.info(f"recovery stop signal found at task {task_id}")
                    token_ids = [RECOVERY_STOP_SIGNAL]
                elif self.use_logprobs:
                    token_ids = tokens[i][:, 0].tolist()[: accept_num[i]]
                else:
                    token_ids = tokens[
                        2
                        + SPECULATE_MAX_BSZ
                        + i * MAX_DRAFT_TOKENS : 2
                        + SPECULATE_MAX_BSZ
                        + i * MAX_DRAFT_TOKENS
                        + accept_num[i]
                    ].tolist()
                if (not recovery_stop) and (len(token_ids) == 0 or token_ids[-1] <= 0):
                    continue
            else:
                token_id = int(tokens[i, 0])
                token_ids = [token_id]
@@ -387,6 +430,7 @@ class TokenProcessor:
                self._record_metrics(task, current_time, token_ids)
            result = RequestOutput(
                request_id=task_id,
                output_type=mtype,
                outputs=CompletionOutput(
                    index=i,
                    send_idx=self.tokens_counter[task_id],
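Each RequestOutput now carries output_type, so a consumer can tell whether a result row came from the target pass (3) or the draft pass (4) before deciding which logprob fields to read. A minimal sketch of such a consumer, with stand-in result types (the real RequestOutput and CompletionOutput classes live in FastDeploy and carry many more fields):

from dataclasses import dataclass, field
from typing import Optional

TARGET, DRAFT = 3, 4  # values named in the diff's "# target:3, draft:4" comment

@dataclass
class FakeOutputs:
    token_ids: list = field(default_factory=list)
    top_logprobs: Optional[list] = None
    draft_top_logprobs: Optional[list] = None

@dataclass
class FakeRequestOutput:
    request_id: str
    output_type: int
    outputs: FakeOutputs

def logprobs_for(result: FakeRequestOutput):
    """Pick the logprob container that matches the pass which produced the result."""
    if result.output_type == DRAFT:
        return result.outputs.draft_top_logprobs
    return result.outputs.top_logprobs

r = FakeRequestOutput("req-0", TARGET, FakeOutputs(token_ids=[5], top_logprobs=[[-0.1]]))
print(logprobs_for(r))  # [[-0.1]]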
@@ -412,16 +456,36 @@ class TokenProcessor:
                    result.outputs.token_ids.append(token_id)
                    task.output_token_ids.append(token_id)
                if self.use_logprobs:
                    # TODO: compatibility support for the speculative decoding scenario
                    result.outputs.logprob = float(scores[i, 0])
                    # Construct top_logprobs
                    topk_token_ids = tokens[i, :].tolist()
                    topk_logprobs = scores[i, :].tolist()
                    sampled_rank = ranks[i].item()
                    result.outputs.top_logprobs = LogprobsLists(
                        logprob_token_ids=[topk_token_ids],
                        logprobs=[topk_logprobs],
                        sampled_token_ranks=[sampled_rank],
                    )

                    if mtype == 3:  # top_logprobs
                        if result.outputs.top_logprobs is None:
                            result.outputs.top_logprobs = LogprobsLists(
                                logprob_token_ids=[topk_token_ids],
                                logprobs=[topk_logprobs],
                                sampled_token_ranks=[sampled_rank],
                            )
                        else:
                            result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids])
                            result.outputs.top_logprobs.logprobs.extend([topk_logprobs])
                            result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank])
                    elif mtype == 4:  # draft_top_logprobs
                        if result.outputs.draft_top_logprobs is None:
                            result.outputs.draft_top_logprobs = LogprobsLists(
                                logprob_token_ids=[topk_token_ids],
                                logprobs=[topk_logprobs],
                                sampled_token_ranks=[sampled_rank],
                            )
                        else:
                            result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids])
                            result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs])
                            result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank])

                if token_id in task.eos_token_ids or is_prefill or recovery_stop:
                    result.finished = True
                    if recovery_stop:
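Within one result, the per-token top-k rows are accumulated by extending three parallel lists, one entry per emitted token position; mtype only decides whether the rows land in top_logprobs or draft_top_logprobs. A self-contained sketch of that accumulation with a stand-in LogprobsLists (only the three fields shown in the diff are modeled):

from dataclasses import dataclass, field

@dataclass
class FakeLogprobsLists:
    logprob_token_ids: list = field(default_factory=list)    # one top-k id row per position
    logprobs: list = field(default_factory=list)             # matching top-k logprob rows
    sampled_token_ranks: list = field(default_factory=list)  # rank of the sampled token per position

def append_position(acc, topk_token_ids, topk_logprobs, sampled_rank):
    """Create the container on first use, then extend it row by row, as in the diff."""
    if acc is None:
        return FakeLogprobsLists([topk_token_ids], [topk_logprobs], [sampled_rank])
    acc.logprob_token_ids.extend([topk_token_ids])
    acc.logprobs.extend([topk_logprobs])
    acc.sampled_token_ranks.extend([sampled_rank])
    return acc

acc = None
for pos in range(3):  # e.g. three accepted tokens from one speculative step
    acc = append_position(acc, [10 + pos, 7], [-0.1 * pos, -2.3], 0)
print(len(acc.logprob_token_ids), acc.sampled_token_ranks)  # 3 [0, 0, 0]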
@@ -442,7 +506,7 @@ class TokenProcessor:
            if not is_prefill or self.cfg.scheduler_config.name == "splitwise":
                batch_result.append(result)

        self.postprocess(batch_result)
        self.postprocess(batch_result, mtype)

    def _record_metrics(self, task, current_time, token_ids):
        """Record all metrics for a task"""