[Speculative Decoding]Support different inferseed in speculate decoding (#5568)

* fix mtp entropy drop in RL

* optimize usage and fix unit test

* optimize padding_sampling_params speed(vectorized)

---------

Co-authored-by: YuBaoku <49938469+EmmonsCurse@users.noreply.github.com>
This commit is contained in:
freeliuzc
2025-12-17 16:14:29 +08:00
committed by GitHub
parent 80fb530ce2
commit 15f5112ecb
3 changed files with 99 additions and 14 deletions

View File

@@ -56,12 +56,40 @@ def top_p_normalize_probs_paddle(
return paddle.zeros_like(probs_sort).put_along_axis_(indices=probs_idx, values=probs_sort, axis=-1)
def padding_sampling_params(top_p, top_k, seq_lens_this_time, seq_lens_encoder):
def padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder):
real_bsz = seq_lens_this_time.shape[0]
repeats = paddle.where(seq_lens_encoder[:real_bsz] == 0, seq_lens_this_time, paddle.ones_like(seq_lens_this_time))
top_p_padding = paddle.repeat_interleave(top_p[:real_bsz], repeats).unsqueeze(1)
top_k_padding = paddle.repeat_interleave(top_k[:real_bsz], repeats).unsqueeze(1)
return top_p_padding, top_k_padding
topp_seed = paddle.repeat_interleave(infer_seed[:real_bsz], repeats).unsqueeze(1)
MAX_INFER_SEED = 9223372036854775806
token_lens = paddle.where(
seq_lens_encoder[:real_bsz] == 0,
seq_lens_this_time,
paddle.ones_like(seq_lens_this_time),
)
batch_start = (paddle.cumsum(token_lens, axis=0) - token_lens.astype("int64")).reshape(-1) # [B]
token_batch_ids = paddle.repeat_interleave(
paddle.arange(token_lens.shape[0], dtype="int64"),
token_lens,
)
token_pos = paddle.arange(topp_seed.shape[0], dtype="int64")
local_pos = token_pos - paddle.gather(batch_start, token_batch_ids)
is_decoder = paddle.gather(seq_lens_encoder[:real_bsz] == 0, token_batch_ids).reshape(-1)
offsets = paddle.where(
is_decoder,
local_pos * 4,
paddle.zeros_like(local_pos),
)
topp_seed[:, 0] = (topp_seed[:, 0] + offsets) % MAX_INFER_SEED
return top_p_padding, top_k_padding, topp_seed
class GuidedDecoding:
@@ -501,7 +529,7 @@ class Sampler(nn.Layer):
sampling_metadata.top_p,
sampling_metadata.top_k,
sampling_metadata.top_k_list,
seed=sampling_metadata.seed[0, 0],
topp_seed=sampling_metadata.seed,
)
logprobs_tensors = (
@@ -725,13 +753,14 @@ class SpeculativeSampler(nn.Layer):
probs = F.softmax(logits)
top_p, top_k = padding_sampling_params(
top_p, top_k, topp_seed = padding_sampling_params(
sampling_metadata.top_p,
sampling_metadata.top_k,
sampling_metadata.seed,
share_inputs["seq_lens_this_time"],
share_inputs["seq_lens_encoder"],
)
_, sampled_token_ids = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, seed=sampling_metadata.seed[0, 0])
_, sampled_token_ids = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, topp_seed=topp_seed)
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
probs,
@@ -1064,13 +1093,7 @@ class MTPSampler(nn.Layer):
)
probs = F.softmax(logits)
top_p, top_k = padding_sampling_params(
sampling_metadata.top_p,
sampling_metadata.top_k,
share_inputs["seq_lens_this_time"],
share_inputs["seq_lens_encoder"],
)
_, next_tokens = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, seed=sampling_metadata.seed[0, 0])
next_tokens = paddle.argmax(probs, axis=-1)
token_ids = None
logprobs_tensors = None