Add ut for speculative sampler (#4650)

This commit is contained in:
GoldPancake
2025-10-30 10:37:49 +08:00
committed by GitHub
parent 1712e1351b
commit fddda50cb9
4 changed files with 478 additions and 11 deletions

View File

@@ -480,7 +480,7 @@ class SpeculativeSampler(nn.Layer):
share_inputs = sampling_metadata.share_inputs
last_logits = logits
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
batch_token_num = share_inputs["batch_token_num"][:real_bsz]
batch_token_num = share_inputs["accept_num"][:real_bsz]
temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
@@ -637,7 +637,7 @@ class SpeculativeSampler(nn.Layer):
batch_token_num = paddle.where(
share_inputs["seq_lens_encoder"][:real_bsz] != 0,
paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]),
share_inputs["accept_num"][:real_bsz].unsqueeze(1),
share_inputs["seq_lens_this_time"],
).squeeze(1)
share_inputs["batch_token_num"] = batch_token_num
ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
@@ -647,11 +647,11 @@ class SpeculativeSampler(nn.Layer):
[paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])]
).astype("int32")
share_inputs["cu_batch_token_offset"] = cu_batch_token_offset
target_logtis = paddle.empty(
target_logits = paddle.empty(
[share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]], dtype=logits.dtype
)
speculate_get_target_logits(
target_logtis,
target_logits,
logits,
cu_batch_token_offset,
ori_cu_batch_token_offset,
@@ -660,25 +660,22 @@ class SpeculativeSampler(nn.Layer):
share_inputs["accept_num"],
)
if self.logprobs_mode == "raw_logprobs":
raw_logprobs = self.compute_logprobs(target_logtis, sampling_metadata)
raw_logprobs = self.compute_logprobs(target_logits, sampling_metadata)
elif self.logprobs_mode == "raw_logits":
raw_logprobs = target_logtis.clone()
raw_logprobs = target_logits.clone()
logprobs_tensors = None
token_ids = share_inputs["accept_tokens"]
if num_logprobs is not None:
token_ids = paddle.concat(
[
share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]]
for i in range(share_inputs["accept_num"][:real_bsz].shape[0])
]
[share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]] for i in range(real_bsz)]
)
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
sampler_output = SamplerOutput(
sampled_token_ids=token_ids,
logprobs_tensors=logprobs_tensors,
token_num_per_batch=batch_token_num,
token_num_per_batch=share_inputs["accept_num"],
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
)