mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature] support mtp distribution equivalence verification (#4699)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
This commit is contained in:
@@ -53,6 +53,14 @@ def top_p_normalize_probs_paddle(
|
||||
return paddle.zeros_like(probs_sort).put_along_axis_(indices=probs_idx, values=probs_sort, axis=-1)
|
||||
|
||||
|
||||
def padding_sampling_params(top_p, top_k, seq_lens_this_time, seq_lens_encoder):
|
||||
real_bsz = seq_lens_this_time.shape[0]
|
||||
repeats = paddle.where(seq_lens_encoder[:real_bsz] == 0, seq_lens_this_time, paddle.ones_like(seq_lens_this_time))
|
||||
top_p_padding = paddle.repeat_interleave(top_p[:real_bsz], repeats).unsqueeze(1)
|
||||
top_k_padding = paddle.repeat_interleave(top_k[:real_bsz], repeats).unsqueeze(1)
|
||||
return top_p_padding, top_k_padding
|
||||
|
||||
|
||||
class GuidedDecoding:
|
||||
"""
|
||||
processor for guided decoding.
|
||||
@@ -595,6 +603,14 @@ class SpeculativeSampler(nn.Layer):
|
||||
|
||||
probs = F.softmax(logits)
|
||||
|
||||
top_p, top_k = padding_sampling_params(
|
||||
sampling_metadata.top_p,
|
||||
sampling_metadata.top_k,
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
)
|
||||
_, sampled_token_ids = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, seed=sampling_metadata.seed[0, 0])
|
||||
|
||||
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
|
||||
probs,
|
||||
sampling_metadata.top_p,
|
||||
@@ -604,6 +620,7 @@ class SpeculativeSampler(nn.Layer):
|
||||
)
|
||||
|
||||
speculate_verify(
|
||||
sampled_token_ids,
|
||||
share_inputs["accept_tokens"],
|
||||
share_inputs["accept_num"],
|
||||
share_inputs["step_idx"],
|
||||
@@ -849,9 +866,13 @@ class MTPSampler(nn.Layer):
|
||||
)
|
||||
probs = F.softmax(logits)
|
||||
|
||||
_, next_tokens = top_k_top_p_sampling(
|
||||
probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list
|
||||
top_p, top_k = padding_sampling_params(
|
||||
sampling_metadata.top_p,
|
||||
sampling_metadata.top_k,
|
||||
share_inputs["seq_lens_this_time"],
|
||||
share_inputs["seq_lens_encoder"],
|
||||
)
|
||||
_, next_tokens = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, seed=sampling_metadata.seed[0, 0])
|
||||
|
||||
token_ids = None
|
||||
logprobs_tensors = None
|
||||
|
||||
Reference in New Issue
Block a user