[Speculative Decoding] Support different infer_seed in speculative decoding (#5569)

* Fix MTP entropy drop in RL

* Fix unit test

* Optimize padding_sampling_params speed (vectorized)
Authored by freeliuzc on 2025-12-17 16:54:24 +08:00, committed via GitHub
parent 8981ce8fa3, commit 19653ee03a
3 changed files with 56 additions and 16 deletions


@@ -61,12 +61,40 @@ def top_p_normalize_probs_paddle(
     return paddle.zeros_like(probs_sort).put_along_axis_(indices=probs_idx, values=probs_sort, axis=-1)
 
 
-def padding_sampling_params(top_p, top_k, seq_lens_this_time, seq_lens_encoder):
+def padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder):
     real_bsz = seq_lens_this_time.shape[0]
     repeats = paddle.where(seq_lens_encoder[:real_bsz] == 0, seq_lens_this_time, paddle.ones_like(seq_lens_this_time))
     top_p_padding = paddle.repeat_interleave(top_p[:real_bsz], repeats).unsqueeze(1)
     top_k_padding = paddle.repeat_interleave(top_k[:real_bsz], repeats).unsqueeze(1)
-    return top_p_padding, top_k_padding
+    topp_seed = paddle.repeat_interleave(infer_seed[:real_bsz], repeats).unsqueeze(1)
+
+    MAX_INFER_SEED = 9223372036854775806
+    token_lens = paddle.where(
+        seq_lens_encoder[:real_bsz] == 0,
+        seq_lens_this_time,
+        paddle.ones_like(seq_lens_this_time),
+    )
+    batch_start = (paddle.cumsum(token_lens, axis=0) - token_lens.astype("int64")).reshape(-1)  # [B]
+    token_batch_ids = paddle.repeat_interleave(
+        paddle.arange(token_lens.shape[0], dtype="int64"),
+        token_lens,
+    )
+    token_pos = paddle.arange(topp_seed.shape[0], dtype="int64")
+    local_pos = token_pos - paddle.gather(batch_start, token_batch_ids)
+    is_decoder = paddle.gather(seq_lens_encoder[:real_bsz] == 0, token_batch_ids).reshape(-1)
+    offsets = paddle.where(
+        is_decoder,
+        local_pos * 4,
+        paddle.zeros_like(local_pos),
+    )
+    topp_seed[:, 0] = (topp_seed[:, 0] + offsets) % MAX_INFER_SEED
+    return top_p_padding, top_k_padding, topp_seed
 
 
 class SamplerProcessor:
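
For readers tracing the vectorized logic above: each request's seed is repeated over its tokens, and every decode token then gets an extra offset of 4 times its position within the request, modulo MAX_INFER_SEED, while prefill tokens keep the base seed. A minimal NumPy sketch of that arithmetic (an illustration, not the Paddle code in this diff):

    import numpy as np

    def expand_seeds(infer_seed, seq_lens_this_time, seq_lens_encoder):
        # Mirrors the seed handling of padding_sampling_params: decode requests
        # contribute seq_lens_this_time tokens, prefill requests one token each.
        MAX_INFER_SEED = 9223372036854775806
        token_lens = np.where(seq_lens_encoder == 0, seq_lens_this_time, 1)
        batch_start = np.cumsum(token_lens) - token_lens            # first token index per request
        token_batch_ids = np.repeat(np.arange(len(token_lens)), token_lens)
        local_pos = np.arange(token_lens.sum()) - batch_start[token_batch_ids]
        offsets = np.where((seq_lens_encoder == 0)[token_batch_ids], local_pos * 4, 0)
        return (np.repeat(infer_seed, token_lens) + offsets) % MAX_INFER_SEED

    # Same case as the all-decode unit test below: seeds [1, 2, 3] with lens [2, 3, 1]
    # expand to [1, 5, 2, 6, 10, 3].
    print(expand_seeds(np.array([1, 2, 3]), np.array([2, 3, 1]), np.array([0, 0, 0])))
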
@@ -353,7 +381,7 @@ class Sampler(nn.Layer):
             sampling_metadata.top_p,
             sampling_metadata.top_k,
             sampling_metadata.top_k_list,
-            seed=sampling_metadata.seed[0, 0],
+            topp_seed=sampling_metadata.seed,
         )
         logprobs_tensors = (
@@ -529,13 +557,14 @@ class SpeculativeSampler(nn.Layer):
         probs = F.softmax(logits)
 
-        top_p, top_k = padding_sampling_params(
+        top_p, top_k, topp_seed = padding_sampling_params(
             sampling_metadata.top_p,
             sampling_metadata.top_k,
+            share_inputs["infer_seed"],
             share_inputs["seq_lens_this_time"],
             share_inputs["seq_lens_encoder"],
         )
-        _, sampled_token_ids = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, seed=sampling_metadata.seed[0, 0])
+        _, sampled_token_ids = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, topp_seed=topp_seed)
 
         verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
             probs,
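
The point of threading topp_seed into top_k_top_p_sampling is that each draft position now draws with its own seed rather than one scalar seed shared by all positions. A hedged toy illustration in plain NumPy (not the FastDeploy sampling kernel): when every position is seeded identically, repeated draws from the same distribution are identical, which is the kind of entropy collapse the commit message ("Fix MTP entropy drop in RL") points at; per-token seeds restore varied draws.

    import numpy as np

    probs = np.array([0.2, 0.5, 0.3])
    scalar_seed = 7

    # One shared seed: every position produces exactly the same token.
    same_seed_draws = [np.random.default_rng(scalar_seed).choice(3, p=probs) for _ in range(4)]

    # Per-position seeds offset by 4, as padding_sampling_params does: draws vary.
    per_token_draws = [np.random.default_rng(scalar_seed + 4 * pos).choice(3, p=probs) for pos in range(4)]

    print(same_seed_draws)   # four identical values
    print(per_token_draws)   # generally not all identical
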
@@ -788,13 +817,7 @@ class MTPSampler(nn.Layer):
         )
         probs = F.softmax(logits)
-        top_p, top_k = padding_sampling_params(
-            sampling_metadata.top_p,
-            sampling_metadata.top_k,
-            share_inputs["seq_lens_this_time"],
-            share_inputs["seq_lens_encoder"],
-        )
-        _, next_tokens = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, seed=sampling_metadata.seed[0, 0])
+        next_tokens = paddle.argmax(probs, axis=-1)
         token_ids = None
         logprobs_tensors = None
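
In contrast, the MTP draft path above no longer samples at all: it keeps the single most likely token per position. A tiny NumPy sketch of the equivalent operation:

    import numpy as np

    # Equivalent of paddle.argmax(probs, axis=-1): pick the highest-probability
    # token index in each row.
    probs = np.array([[0.1, 0.7, 0.2],
                      [0.6, 0.3, 0.1]])
    next_tokens = np.argmax(probs, axis=-1)
    print(next_tokens)  # [1 0]
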


@@ -144,9 +144,12 @@ class GPUModelRunner(ModelRunnerBase):
         # Initialize share inputs
         self._init_share_inputs(self.parallel_config.max_num_seqs)
+        increment_value = (
+            4 if not self.speculative_decoding else (self.speculative_config.num_speculative_tokens + 1) * 4
+        )
         self.infer_seed_increment = paddle.full(
             shape=[self.parallel_config.max_num_seqs, 1],
-            fill_value=4,
+            fill_value=increment_value,
             dtype="int64",
         ).cpu()
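
The widened increment ties back to the per-token offsets in padding_sampling_params: within one speculative step, up to num_speculative_tokens + 1 tokens each consume a seed offset by 4, so advancing the base seed by (num_speculative_tokens + 1) * 4 per step presumably keeps consecutive steps from reusing those seeds. A minimal sketch of just the arithmetic, with a hypothetical num_speculative_tokens of 3:

    def seed_increment(speculative_decoding: bool, num_speculative_tokens: int) -> int:
        # Same expression as increment_value in the hunk above.
        return 4 if not speculative_decoding else (num_speculative_tokens + 1) * 4

    assert seed_increment(False, 0) == 4
    assert seed_increment(True, 3) == 16  # covers per-token offsets 0, 4, 8, 12 within one step
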


@@ -27,10 +27,13 @@ class TestPaddingSamplingParams(unittest.TestCase):
     def test_all_decode(self):
         top_p = paddle.to_tensor([0.8, 0.9, 0.95], dtype="float32")
         top_k = paddle.to_tensor([10, 20, 30], dtype="int64")
+        seed = paddle.to_tensor([1, 2, 3], dtype="int64")
         seq_lens_this_time = paddle.to_tensor([2, 3, 1], dtype="int64")
         seq_lens_encoder = paddle.to_tensor([0, 0, 0], dtype="int64")
 
-        top_p_padding, top_k_padding = padding_sampling_params(top_p, top_k, seq_lens_this_time, seq_lens_encoder)
+        top_p_padding, top_k_padding, topp_seed = padding_sampling_params(
+            top_p, top_k, seed, seq_lens_this_time, seq_lens_encoder
+        )
 
         expected_len = sum(seq_lens_this_time.numpy())
         self.assertEqual(top_p_padding.shape[0], expected_len)
@@ -39,28 +42,39 @@ class TestPaddingSamplingParams(unittest.TestCase):
         expected_top_p = np.repeat([0.8, 0.9, 0.95], [2, 3, 1]).reshape(-1, 1)
         np.testing.assert_allclose(top_p_padding.numpy(), expected_top_p, rtol=1e-6)
 
+        expected_topp_seed = np.array([1, 5, 2, 6, 10, 3]).reshape(-1, 1)
+        np.testing.assert_allclose(topp_seed.numpy(), expected_topp_seed)
+
     def test_partial_decode(self):
         top_p = paddle.to_tensor([0.7, 0.6, 0.5], dtype="float32")
         top_k = paddle.to_tensor([15, 25, 35], dtype="int64")
+        seed = paddle.to_tensor([1, 2, 3], dtype="int64")
         seq_lens_this_time = paddle.to_tensor([3, 2, 4], dtype="int64")
         seq_lens_encoder = paddle.to_tensor([0, 1, 0], dtype="int64")
 
-        top_p_padding, top_k_padding = padding_sampling_params(top_p, top_k, seq_lens_this_time, seq_lens_encoder)
+        top_p_padding, top_k_padding, topp_seed = padding_sampling_params(
+            top_p, top_k, seed, seq_lens_this_time, seq_lens_encoder
+        )
 
         expected_repeats = [3, 1, 4]
         expected_top_p = np.repeat([0.7, 0.6, 0.5], expected_repeats).reshape(-1, 1)
         expected_top_k = np.repeat([15, 25, 35], expected_repeats).reshape(-1, 1)
+        expected_topp_seed = np.array([1, 5, 9, 2, 3, 7, 11, 15]).reshape(-1, 1)
 
         np.testing.assert_allclose(top_p_padding.numpy(), expected_top_p, rtol=1e-6)
         np.testing.assert_array_equal(top_k_padding.numpy(), expected_top_k)
+        np.testing.assert_array_equal(topp_seed.numpy(), expected_topp_seed)
 
     def test_all_prefill(self):
         top_p = paddle.to_tensor([0.5, 0.6], dtype="float32")
         top_k = paddle.to_tensor([5, 6], dtype="int64")
+        seed = paddle.to_tensor([1, 1], dtype="int64")
         seq_lens_this_time = paddle.to_tensor([4, 3], dtype="int64")
         seq_lens_encoder = paddle.to_tensor([1, 2], dtype="int64")
 
-        top_p_padding, top_k_padding = padding_sampling_params(top_p, top_k, seq_lens_this_time, seq_lens_encoder)
+        top_p_padding, top_k_padding, topp_seed = padding_sampling_params(
+            top_p, top_k, seed, seq_lens_this_time, seq_lens_encoder
+        )
 
         expected_top_p = np.array([[0.5], [0.6]])
         expected_top_k = np.array([[5], [6]])