support logprob in mtp

This commit is contained in:
Deleter-D
2025-09-25 20:11:57 +08:00
parent bab779011c
commit a46aa06194
11 changed files with 1171 additions and 25 deletions

View File

@@ -60,11 +60,20 @@ class TokenProcessor:
self.use_logprobs = self.cfg.model_config.enable_logprob
if self.speculative_decoding:
self.output_tokens = paddle.full(
shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2],
fill_value=2,
dtype="int64",
)
if self.use_logprobs:
self.output_tokens = paddle.full(
shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1) + MAX_BSZ + 3, 1], fill_value=2, dtype="int64"
)
self.output_scores = paddle.full(
shape=[MAX_BSZ * MAX_DRAFT_TOKENS * (K + 1), 1], fill_value=0.0, dtype="float32"
)
self.output_ranks = paddle.full(shape=[MAX_BSZ * MAX_DRAFT_TOKENS], fill_value=0, dtype="int64")
else:
self.output_tokens = paddle.full(
shape=[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2],
fill_value=2,
dtype="int64",
)
elif self.use_logprobs:
self.output_tokens = paddle.full(shape=[MAX_BSZ * (K + 1) + 2, 1], fill_value=2, dtype="int64")
self.output_scores = paddle.full(shape=[MAX_BSZ * (K + 1), 1], fill_value=0.0, dtype="float32")
@@ -149,6 +158,7 @@ class TokenProcessor:
get_output_ep,
get_output_topk,
speculate_get_output,
speculate_get_output_topk,
)
rank_id = self.cfg.parallel_config.local_data_parallel_id
@@ -156,9 +166,24 @@ class TokenProcessor:
try:
is_blocking = True
if self.speculative_decoding:
speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
if self.output_tokens[0] == -2:
continue
if self.use_logprobs:
speculate_get_output_topk(
self.output_tokens,
self.output_scores,
self.output_ranks,
K,
rank_id,
is_blocking,
)
print(f"[TokenProcessor] output_tokens: {self.output_tokens}")
print(f"[TokenProcessor] output_scores: {self.output_scores}")
print(f"[TokenProcessor] output_ranks: {self.output_ranks}")
if self.output_tokens[0, 0] == -2:
continue
else:
speculate_get_output(self.output_tokens, rank_id, is_blocking, False)
if self.output_tokens[0] == -2:
continue
else:
if self.use_logprobs: