Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Speculative Decoding] Add draft_logprobs Support for Speculative Decode MTP (#4467)
* feat: add draft_logprobs for Speculative Decode MTP
* fix: postprocess for speculative decode
* test: test_speculative_decoding_use_logprobs
* fix: test_completion_echo
* fix: test_max_streaming_tokens

Co-authored-by: Jiang-Jia-Jun <163579578+Jiang-Jia-Jun@users.noreply.github.com>
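At the result level, the change is easiest to see on the output objects: when MTP speculative decoding runs with logprobs enabled, a completion now carries the target model's top-k logprobs in top_logprobs and the draft (MTP) model's top-k logprobs in the new draft_top_logprobs field. The sketch below is a self-contained illustration of that shape only; the Toy* dataclasses are stand-ins for FastDeploy's LogprobsLists and CompletionOutput, and the numbers are made up.

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyLogprobsLists:
    # Stand-in for LogprobsLists: one row per emitted token position.
    logprob_token_ids: List[List[int]] = field(default_factory=list)
    logprobs: List[List[float]] = field(default_factory=list)
    sampled_token_ranks: List[int] = field(default_factory=list)


@dataclass
class ToyCompletionOutput:
    token_ids: List[int] = field(default_factory=list)
    top_logprobs: Optional[ToyLogprobsLists] = None         # target-model top-k
    draft_top_logprobs: Optional[ToyLogprobsLists] = None   # MTP draft top-k (added by this PR)


out = ToyCompletionOutput(token_ids=[11, 27])
out.top_logprobs = ToyLogprobsLists([[11, 5, 9]], [[-0.1, -2.3, -3.0]], [0])
out.draft_top_logprobs = ToyLogprobsLists([[11, 7, 2]], [[-0.2, -1.9, -2.8]], [0])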
@@ -22,6 +22,7 @@ import traceback
 import weakref
 from collections import Counter
 from concurrent.futures import ThreadPoolExecutor
+from typing import List
 
 import numpy as np
 import paddle
@@ -67,11 +68,20 @@ class TokenProcessor:
         self.use_logprobs = self.cfg.model_config.enable_logprob
 
         if self.speculative_decoding:
-            self.output_tokens = paddle.full(
-                shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2],
-                fill_value=2,
-                dtype="int64",
-            )
+            if self.use_logprobs:
+                self.output_tokens = paddle.full(
+                    shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1) + MAX_BSZ + 3, 1], fill_value=2, dtype="int64"
+                )
+                self.output_scores = paddle.full(
+                    shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1), 1], fill_value=0.0, dtype="float32"
+                )
+                self.output_ranks = paddle.full(shape=[MAX_BSZ * MAX_DRAFT_TOKENS], fill_value=0, dtype="int64")
+            else:
+                self.output_tokens = paddle.full(
+                    shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2],
+                    fill_value=2,
+                    dtype="int64",
+                )
         elif self.use_logprobs:
             self.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64")
             self.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32")
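For orientation, a small sketch of how the flat buffer sizes above line up with the header/payload layout that the parsing code later in this diff reads back (signal, mtype and batch in the first slots, MAX_BSZ accept counts, then top-(K + 1) candidates per draft position per request). The constant values here are placeholders, not FastDeploy's real limits.

# Placeholder sizes; the real MAX_BSZ, MAX_DRAFT_TOKENS and K come from
# FastDeploy's constants and the server configuration.
MAX_BSZ = 256
MAX_DRAFT_TOKENS = 6
K = 20

# Speculative decoding with logprobs enabled:
tokens_len = MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1) + MAX_BSZ + 3  # header + accept_num + token ids
scores_len = MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1)                # one score per candidate
ranks_len = MAX_BSZ * MAX_DRAFT_TOKENS                           # sampled rank per draft position

# The reader later in the diff consumes the same layout:
#   output_tokens[0, 0]          -> -2 means "no batch ready yet"
#   output_tokens[1, 0]          -> mtype (3 = target, 4 = draft)
#   output_tokens[2, 0]          -> batch size
#   output_tokens[3 : batch + 3] -> accepted-token count per request
#   output_tokens[3 + MAX_BSZ :] -> token ids, reshaped to [batch, MAX_DRAFT_TOKENS, K + 1]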
@@ -107,6 +117,7 @@ class TokenProcessor:
         self.executor = ThreadPoolExecutor(max_workers=1)
         self.prefill_result_status = dict()
         self._finalizer = weakref.finalize(self, self._cleanup_resources)
+        self._batch_result_buffer = None
 
     def _cleanup_resources(self):
         """Cleaning up shared memory resources"""
@@ -312,6 +323,7 @@ class TokenProcessor:
                 get_output_ep,
                 get_output_topk,
                 speculate_get_output,
+                speculate_get_output_topk,
             )
         rank_id = self.cfg.parallel_config.local_data_parallel_id
 
@@ -319,15 +331,27 @@ class TokenProcessor:
             try:
                 is_blocking = True
                 if self.speculative_decoding:
-                    if (
-                        self.cfg.parallel_config.enable_expert_parallel
-                        and self.cfg.parallel_config.data_parallel_size > 1
-                    ):
-                        speculate_get_output(self.output_tokens, rank_id, is_blocking, True)
+                    if self.use_logprobs:
+                        speculate_get_output_topk(
+                            self.output_tokens,
+                            self.output_scores,
+                            self.output_ranks,
+                            K,
+                            rank_id,
+                            is_blocking,
+                        )
+                        if self.output_tokens[0, 0] == -2:
+                            continue
                     else:
-                        speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
-                    if self.output_tokens[0] == -2:
-                        continue
+                        if (
+                            self.cfg.parallel_config.enable_expert_parallel
+                            and self.cfg.parallel_config.data_parallel_size > 1
+                        ):
+                            speculate_get_output(self.output_tokens, rank_id, is_blocking, True)
+                        else:
+                            speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
+                        if self.output_tokens[0] == -2:
+                            continue
                 else:
                     if self.use_logprobs:
                         get_output_topk(
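The branch added above boils down to one blocking fetch into the preallocated buffers followed by a sentinel check. A schematic version of that loop, with fetch_topk_outputs as a hypothetical stand-in for speculate_get_output_topk:

NO_OUTPUT_YET = -2  # sentinel written into output_tokens[0, 0] when nothing is ready


def drain_outputs(fetch_topk_outputs, output_tokens, output_scores, output_ranks, handle_batch):
    """Schematic reader loop; fetch_topk_outputs is assumed to block and fill
    the three buffers in place, mirroring speculate_get_output_topk above."""
    while True:
        fetch_topk_outputs(output_tokens, output_scores, output_ranks)
        if output_tokens[0, 0] == NO_OUTPUT_YET:
            continue  # no batch produced yet, poll again
        handle_batch(output_tokens, output_scores, output_ranks)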
@@ -372,7 +396,7 @@ class TokenProcessor:
 
         self.executor.submit(process_metrics)
 
-    def postprocess(self, batch_result):
+    def postprocess(self, batch_result: List[RequestOutput], mtype=3):
         """
         single post-processing function
 
@@ -380,7 +404,28 @@ class TokenProcessor:
            batch_result (list): batch results
        """
        try:
-            self.cached_generated_tokens.put_results(batch_result)
+            if self.cfg.speculative_config.method and self.use_logprobs:
+                if mtype == 3:  # target
+                    finished_batch_result, unfinished_batch_result = [], []
+                    for r in batch_result:
+                        (finished_batch_result if r.finished else unfinished_batch_result).append(r)
+                    if finished_batch_result:
+                        self.cached_generated_tokens.put_results(batch_result)
+                    else:
+                        self._batch_result_buffer = unfinished_batch_result
+                elif mtype == 4:  # draft
+                    target_batch_result = []
+                    draft_batch_result = batch_result
+                    if self._batch_result_buffer is not None:
+                        for target, decode in zip(self._batch_result_buffer, draft_batch_result):
+                            target.outputs.draft_top_logprobs = decode.outputs.draft_top_logprobs
+                            target_batch_result.append(target)
+                        self._batch_result_buffer = None
+                        self.cached_generated_tokens.put_results(target_batch_result)
+                    else:
+                        self.cached_generated_tokens.put_results(batch_result)
+            else:
+                self.cached_generated_tokens.put_results(batch_result)
        except Exception as e:
            llm_logger.error(f"Error in TokenProcessor's postprocess: {e}, {str(traceback.format_exc())}")
 
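In short, postprocess now pairs each target batch (mtype 3) with the draft batch (mtype 4) that follows it: unfinished target results are parked in _batch_result_buffer, and when the matching draft batch arrives its draft_top_logprobs are copied onto the buffered targets before anything is emitted. A stripped-down sketch of that pairing logic, with a toy sink standing in for cached_generated_tokens and plain objects instead of RequestOutput:

class ToySink:
    """Stand-in for cached_generated_tokens; just collects emitted batches."""

    def __init__(self):
        self.emitted = []

    def put_results(self, batch):
        self.emitted.extend(batch)


class TargetDraftMerger:
    TARGET, DRAFT = 3, 4  # mtype values used in the diff

    def __init__(self, sink):
        self.sink = sink
        self._batch_result_buffer = None  # unfinished target results awaiting draft logprobs

    def postprocess(self, batch_result, mtype):
        if mtype == self.TARGET:
            finished = [r for r in batch_result if r.finished]
            unfinished = [r for r in batch_result if not r.finished]
            if finished:
                self.sink.put_results(batch_result)      # finished requests go out right away
            else:
                self._batch_result_buffer = unfinished   # hold until the draft batch arrives
        elif mtype == self.DRAFT:
            if self._batch_result_buffer is not None:
                merged = []
                for target, draft in zip(self._batch_result_buffer, batch_result):
                    target.outputs.draft_top_logprobs = draft.outputs.draft_top_logprobs
                    merged.append(target)
                self._batch_result_buffer = None
                self.sink.put_results(merged)
            else:
                self.sink.put_results(batch_result)

One consequence of this pairing, visible later in the diff, is that only target batches (mtype 3) can mark a request finished; draft batches merely contribute draft_top_logprobs before the merged results are released.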
@@ -471,9 +516,25 @@ class TokenProcessor:
         tokens = self.output_tokens.numpy()
         scores = None
         ranks = None
+        # target:3, draft:4
+        mtype = 3
         if self.cfg.speculative_config.method:
-            batch = self.output_tokens[1]
-            accept_num = tokens[2 : batch + 2]
+            if self.use_logprobs:
+                mtype = int(self.output_tokens[1, 0].item())
+                batch = self.output_tokens[2, 0]
+                accept_num = [int(num[0]) for num in self.output_tokens[3 : batch + 3]]
+                tokens = tokens[3 + MAX_BSZ : 3 + MAX_BSZ + batch * MAX_DRAFT_TOKENS * (K + 1)].reshape(
+                    [batch, MAX_DRAFT_TOKENS, K + 1]
+                )
+                scores = (
+                    self.output_scores[: batch * MAX_DRAFT_TOKENS * (K + 1)]
+                    .numpy()
+                    .reshape([batch, MAX_DRAFT_TOKENS, K + 1])
+                )
+                ranks = self.output_ranks[: batch * MAX_DRAFT_TOKENS].numpy().reshape([batch, MAX_DRAFT_TOKENS])
+            else:
+                batch = self.output_tokens[1]
+                accept_num = tokens[2 : batch + 2]
             self._record_speculative_decoding_mertics(accept_num)
         elif self.use_logprobs:
             batch = self.output_tokens[1, 0]
@@ -501,6 +562,8 @@ class TokenProcessor:
                 if recovery_stop:
                     llm_logger.info(f"recovery stop signal found at task {task_id}")
                     token_ids = [RECOVERY_STOP_SIGNAL]
+                elif self.use_logprobs:
+                    token_ids = tokens[i][:, 0].tolist()[: accept_num[i]]
                 else:
                     token_ids = tokens[
                         2
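Given the reshaped arrays from the previous hunk (tokens and scores as [batch, MAX_DRAFT_TOKENS, K + 1], ranks as [batch, MAX_DRAFT_TOKENS]), the new elif branch simply takes column 0 of request i and truncates it to accept_num[i]. A small NumPy illustration with made-up sizes:

import numpy as np

# Made-up sizes, for illustration only.
batch, MAX_DRAFT_TOKENS, K = 2, 4, 3
rng = np.random.default_rng(0)

tokens = rng.integers(0, 1000, size=(batch, MAX_DRAFT_TOKENS, K + 1))
scores = rng.standard_normal((batch, MAX_DRAFT_TOKENS, K + 1)).astype(np.float32)
ranks = rng.integers(0, K + 1, size=(batch, MAX_DRAFT_TOKENS))
accept_num = [3, 1]  # tokens actually accepted per request in this step

i = 0
# Column 0 holds the sampled token at each draft position; the full last axis
# holds the top-(K + 1) candidates used to build the logprob lists.
token_ids = tokens[i][:, 0].tolist()[: accept_num[i]]
first_token_logprob = float(scores[i, 0, 0])
first_token_topk_ids = tokens[i, 0, :].tolist()
first_token_rank = int(ranks[i, 0])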
@@ -556,6 +619,7 @@ class TokenProcessor:
             self._record_metrics(task, current_time, token_ids)
             result = RequestOutput(
                 request_id=task_id,
+                output_type=mtype,
                 outputs=CompletionOutput(
                     index=i,
                     send_idx=self.tokens_counter[task_id],
@@ -575,29 +639,54 @@ class TokenProcessor:
             if is_prefill and len(token_ids) > 1:
                 result.outputs.draft_token_ids = copy.deepcopy(token_ids)
 
-            for token_id in token_ids:
+            for batch_token_index in range(len(token_ids)):
+                token_id = token_ids[batch_token_index]
                 self.tokens_counter[task_id] += 1
                 if token_id != RECOVERY_STOP_SIGNAL:
                     if not (envs.FD_ENABLE_INTERNAL_ADAPTER and token_id in task.eos_token_ids):
                         result.outputs.token_ids.append(token_id)
                         task.output_token_ids.append(token_id)
                     if self.use_logprobs:
-                        result.outputs.logprob = float(scores[i, 0])
-                        # Construct top_logprobs
-                        topk_token_ids = tokens[i, :].tolist()
-                        topk_logprobs = scores[i, :].tolist()
-                        sampled_rank = ranks[i].item()
-                        result.outputs.top_logprobs = LogprobsLists(
-                            logprob_token_ids=[topk_token_ids],
-                            logprobs=[topk_logprobs],
-                            sampled_token_ranks=[sampled_rank],
-                        )
-                if token_id in task.eos_token_ids or is_prefill or recovery_stop:
+                        if self.cfg.speculative_config.method:
+                            result.outputs.logprob = float(scores[i, batch_token_index, 0])
+                            topk_token_ids = tokens[i, batch_token_index, :].tolist()
+                            topk_logprobs = scores[i, batch_token_index, :].tolist()
+                            sampled_rank = ranks[i, batch_token_index].item()
+                        else:
+                            result.outputs.logprob = float(scores[i, 0])
+                            topk_token_ids = tokens[i, :].tolist()
+                            topk_logprobs = scores[i, :].tolist()
+                            sampled_rank = ranks[i].item()
+
+                        if mtype == 3:  # top_logprobs
+                            if result.outputs.top_logprobs is None:
+                                result.outputs.top_logprobs = LogprobsLists(
+                                    logprob_token_ids=[topk_token_ids],
+                                    logprobs=[topk_logprobs],
+                                    sampled_token_ranks=[sampled_rank],
+                                )
+                            else:
+                                result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids])
+                                result.outputs.top_logprobs.logprobs.extend([topk_logprobs])
+                                result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank])
+                        elif mtype == 4:  # draft_top_logprobs
+                            if result.outputs.draft_top_logprobs is None:
+                                result.outputs.draft_top_logprobs = LogprobsLists(
+                                    logprob_token_ids=[topk_token_ids],
+                                    logprobs=[topk_logprobs],
+                                    sampled_token_ranks=[sampled_rank],
+                                )
+                            else:
+                                result.outputs.draft_top_logprobs.logprob_token_ids.extend([topk_token_ids])
+                                result.outputs.draft_top_logprobs.logprobs.extend([topk_logprobs])
+                                result.outputs.draft_top_logprobs.sampled_token_ranks.extend([sampled_rank])
+                if mtype == 3 and (token_id in task.eos_token_ids or is_prefill or recovery_stop):
                     result.finished = True
                     if recovery_stop:
                         result.error_msg = "Recover is not supported, the result is incomplete!"
                     llm_logger.info(
-                        f"Request: {task_id} finished, number of " f"generated tokens: {self.tokens_counter[task_id]}."
+                        f"Request: {task_id} finished, number of "
+                        f"generated tokens: {self.tokens_counter[task_id]}, token_id:{token_id},is_prefill:{is_prefill},recovery_stop:{recovery_stop}"
                     )
                     llm_logger.info(
                         f"Request: {task_id} token ratio: {self.tokens_counter[task_id] / (time.time() - task.inference_start_time)}"
@@ -616,7 +705,7 @@ class TokenProcessor:
            ):
                batch_result.append(result)
 
-        self.postprocess(batch_result)
+        self.postprocess(batch_result, mtype)
 
    def _record_metrics(self, task, current_time, token_ids):
        """Record all metrics for a task"""