[XPU] modify speculate_verify (#5522)

This commit is contained in:
RuohengMa
2025-12-23 14:50:30 +08:00
committed by GitHub
parent 945a1bc4e2
commit 2c3c983b96
7 changed files with 278 additions and 354 deletions

View File

@@ -71,7 +71,7 @@ def padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_le
paddle.ones_like(seq_lens_this_time),
)
batch_start = (paddle.cumsum(token_lens, axis=0) - token_lens.astype("int64")).reshape(-1) # [B]
batch_start = (paddle.cumsum(token_lens, axis=0, dtype="int64") - token_lens.astype("int64")).reshape([-1]) # [B]
token_batch_ids = paddle.repeat_interleave(
paddle.arange(token_lens.shape[0], dtype="int64"),
token_lens,
@@ -79,7 +79,7 @@ def padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_le
token_pos = paddle.arange(topp_seed.shape[0], dtype="int64")
local_pos = token_pos - paddle.gather(batch_start, token_batch_ids)
is_decoder = paddle.gather(seq_lens_encoder[:real_bsz] == 0, token_batch_ids).reshape(-1)
is_decoder = paddle.gather(seq_lens_encoder[:real_bsz] == 0, token_batch_ids).reshape([-1])
offsets = paddle.where(
is_decoder,
@@ -879,6 +879,15 @@ class SpeculativeSampler(nn.Layer):
probs = F.softmax(logits)
top_p, top_k, topp_seed = padding_sampling_params(
sampling_metadata.top_p,
sampling_metadata.top_k,
sampling_metadata.seed,
share_inputs["seq_lens_this_time"],
paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]),
)
_, sampled_token_ids = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, topp_seed=topp_seed)
verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
probs,
sampling_metadata.top_p,
@@ -888,6 +897,7 @@ class SpeculativeSampler(nn.Layer):
)
speculate_verify(
sampled_token_ids,
share_inputs["accept_tokens"],
share_inputs["accept_num"],
share_inputs["step_idx"],