Mirror of https://github.com/PaddlePaddle/FastDeploy.git (last synced 2025-12-24 13:28:13 +08:00)
Add unit tests (UT) for speculative sampler (#4650)
This commit is contained in:
@@ -480,7 +480,7 @@ class SpeculativeSampler(nn.Layer):
|
||||
share_inputs = sampling_metadata.share_inputs
|
||||
last_logits = logits
|
||||
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
|
||||
batch_token_num = share_inputs["batch_token_num"][:real_bsz]
|
||||
batch_token_num = share_inputs["accept_num"][:real_bsz]
|
||||
|
||||
temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
|
||||
top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
|
||||
@@ -637,7 +637,7 @@ class SpeculativeSampler(nn.Layer):
|
||||
batch_token_num = paddle.where(
|
||||
share_inputs["seq_lens_encoder"][:real_bsz] != 0,
|
||||
paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]),
|
||||
share_inputs["accept_num"][:real_bsz].unsqueeze(1),
|
||||
share_inputs["seq_lens_this_time"],
|
||||
).squeeze(1)
|
||||
share_inputs["batch_token_num"] = batch_token_num
|
||||
ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
|
||||
@@ -647,11 +647,11 @@ class SpeculativeSampler(nn.Layer):
|
||||
[paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])]
|
||||
).astype("int32")
|
||||
share_inputs["cu_batch_token_offset"] = cu_batch_token_offset
|
||||
target_logtis = paddle.empty(
|
||||
target_logits = paddle.empty(
|
||||
[share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]], dtype=logits.dtype
|
||||
)
|
||||
speculate_get_target_logits(
|
||||
target_logtis,
|
||||
target_logits,
|
||||
logits,
|
||||
cu_batch_token_offset,
|
||||
ori_cu_batch_token_offset,
|
||||
@@ -660,25 +660,22 @@ class SpeculativeSampler(nn.Layer):
|
||||
share_inputs["accept_num"],
|
||||
)
|
||||
if self.logprobs_mode == "raw_logprobs":
|
||||
raw_logprobs = self.compute_logprobs(target_logtis, sampling_metadata)
|
||||
raw_logprobs = self.compute_logprobs(target_logits, sampling_metadata)
|
||||
elif self.logprobs_mode == "raw_logits":
|
||||
raw_logprobs = target_logtis.clone()
|
||||
raw_logprobs = target_logits.clone()
|
||||
|
||||
logprobs_tensors = None
|
||||
token_ids = share_inputs["accept_tokens"]
|
||||
if num_logprobs is not None:
|
||||
token_ids = paddle.concat(
|
||||
[
|
||||
share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]]
|
||||
for i in range(share_inputs["accept_num"][:real_bsz].shape[0])
|
||||
]
|
||||
[share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]] for i in range(real_bsz)]
|
||||
)
|
||||
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
|
||||
|
||||
sampler_output = SamplerOutput(
|
||||
sampled_token_ids=token_ids,
|
||||
logprobs_tensors=logprobs_tensors,
|
||||
token_num_per_batch=batch_token_num,
|
||||
token_num_per_batch=share_inputs["accept_num"],
|
||||
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user