[Feature] support mtp distribution equivalence verification (#4699)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled

This commit is contained in:
GoldPancake
2025-10-31 11:45:04 +08:00
committed by GitHub
parent 28de91b50f
commit 1f3ce65b58
6 changed files with 257 additions and 88 deletions

View File

@@ -53,6 +53,14 @@ def top_p_normalize_probs_paddle(
return paddle.zeros_like(probs_sort).put_along_axis_(indices=probs_idx, values=probs_sort, axis=-1)
def padding_sampling_params(top_p, top_k, seq_lens_this_time, seq_lens_encoder):
    """Expand per-request sampling params to one row per sampled token.

    For decode-phase requests (``seq_lens_encoder == 0``) each request may emit
    several tokens this step (speculative/MTP drafting), so its ``top_p``/``top_k``
    value is repeated ``seq_lens_this_time`` times; prefill requests contribute
    exactly one row.

    Args:
        top_p: per-request top-p tensor; only the first ``real_bsz`` rows are used.
        top_k: per-request top-k tensor; only the first ``real_bsz`` rows are used.
        seq_lens_this_time: tokens produced per request this step, shape [real_bsz].
        seq_lens_encoder: encoder (prefill) lengths; zero marks a decode request.

    Returns:
        Tuple ``(top_p_padded, top_k_padded)``, each of shape [total_tokens, 1].
    """
    batch = seq_lens_this_time.shape[0]
    # Decode requests repeat their params per drafted token; prefill gets one row.
    in_decode = seq_lens_encoder[:batch] == 0
    row_counts = paddle.where(in_decode, seq_lens_this_time, paddle.ones_like(seq_lens_this_time))
    padded_p = paddle.repeat_interleave(top_p[:batch], row_counts).unsqueeze(1)
    padded_k = paddle.repeat_interleave(top_k[:batch], row_counts).unsqueeze(1)
    return padded_p, padded_k
class GuidedDecoding:
"""
processor for guided decoding.
@@ -595,6 +603,14 @@ class SpeculativeSampler(nn.Layer):
probs = F.softmax(logits)
top_p, top_k = padding_sampling_params(
sampling_metadata.top_p,
sampling_metadata.top_k,
share_inputs["seq_lens_this_time"],
share_inputs["seq_lens_encoder"],
)
_, sampled_token_ids = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, seed=sampling_metadata.seed[0, 0])
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
probs,
sampling_metadata.top_p,
@@ -604,6 +620,7 @@ class SpeculativeSampler(nn.Layer):
)
speculate_verify(
sampled_token_ids,
share_inputs["accept_tokens"],
share_inputs["accept_num"],
share_inputs["step_idx"],
@@ -849,9 +866,13 @@ class MTPSampler(nn.Layer):
)
probs = F.softmax(logits)
_, next_tokens = top_k_top_p_sampling(
probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list
top_p, top_k = padding_sampling_params(
sampling_metadata.top_p,
sampling_metadata.top_k,
share_inputs["seq_lens_this_time"],
share_inputs["seq_lens_encoder"],
)
_, next_tokens = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, seed=sampling_metadata.seed[0, 0])
token_ids = None
logprobs_tensors = None