[Cherry-Pick][CI] Support different infer_seed in speculative decoding (#5568) (#5597)

* fix MTP entropy drop in RL

* optimize usage and fix unit test

* optimize padding_sampling_params speed (vectorized)
Author: freeliuzc
Authored: 2025-12-17 16:53:47 +08:00
Committed by: GitHub
Parent commit: c19af496cb
Commit: a7359d1c1d
3 changed files with 99 additions and 14 deletions

Changed file 1 of 3: fastdeploy/model_executor/layers/sample/sampler.py

@@ -56,12 +56,40 @@ def top_p_normalize_probs_paddle(
     return paddle.zeros_like(probs_sort).put_along_axis_(indices=probs_idx, values=probs_sort, axis=-1)


-def padding_sampling_params(top_p, top_k, seq_lens_this_time, seq_lens_encoder):
+def padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder):
     real_bsz = seq_lens_this_time.shape[0]
     repeats = paddle.where(seq_lens_encoder[:real_bsz] == 0, seq_lens_this_time, paddle.ones_like(seq_lens_this_time))
     top_p_padding = paddle.repeat_interleave(top_p[:real_bsz], repeats).unsqueeze(1)
     top_k_padding = paddle.repeat_interleave(top_k[:real_bsz], repeats).unsqueeze(1)
-    return top_p_padding, top_k_padding
+    topp_seed = paddle.repeat_interleave(infer_seed[:real_bsz], repeats).unsqueeze(1)
+
+    MAX_INFER_SEED = 9223372036854775806
+    token_lens = paddle.where(
+        seq_lens_encoder[:real_bsz] == 0,
+        seq_lens_this_time,
+        paddle.ones_like(seq_lens_this_time),
+    )
+    batch_start = (paddle.cumsum(token_lens, axis=0) - token_lens.astype("int64")).reshape(-1)  # [B]
+    token_batch_ids = paddle.repeat_interleave(
+        paddle.arange(token_lens.shape[0], dtype="int64"),
+        token_lens,
+    )
+    token_pos = paddle.arange(topp_seed.shape[0], dtype="int64")
+    local_pos = token_pos - paddle.gather(batch_start, token_batch_ids)
+    is_decoder = paddle.gather(seq_lens_encoder[:real_bsz] == 0, token_batch_ids).reshape(-1)
+    offsets = paddle.where(
+        is_decoder,
+        local_pos * 4,
+        paddle.zeros_like(local_pos),
+    )
+    topp_seed[:, 0] = (topp_seed[:, 0] + offsets) % MAX_INFER_SEED
+    return top_p_padding, top_k_padding, topp_seed


 class GuidedDecoding:
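
For readers tracing the new seed logic: each decoder token in a speculative step gets its request's infer_seed plus a per-token offset of 4 times its local position (modulo MAX_INFER_SEED), while encoder/prefill requests keep a single unshifted seed, so draft tokens sampled in one forward pass no longer share one RNG seed. Below is a minimal pure-Python reference of that behaviour; it is illustrative only, not part of the commit, and the helper name padding_seed_reference is invented here.

MAX_INFER_SEED = 9223372036854775806  # same constant as in padding_sampling_params

def padding_seed_reference(infer_seed, seq_lens_this_time, seq_lens_encoder):
    # Mirrors the vectorized paddle code with an explicit loop:
    # decoder requests emit one seed per token, offset by 4 per position;
    # encoder (prefill) requests emit a single unshifted seed.
    padded = []
    for seed, n_tokens, enc_len in zip(infer_seed, seq_lens_this_time, seq_lens_encoder):
        if enc_len == 0:
            padded.extend((seed + 4 * k) % MAX_INFER_SEED for k in range(n_tokens))
        else:
            padded.append(seed % MAX_INFER_SEED)
    return padded

# Same toy batch as the new unit test below: requests 0, 2, 3 are decoders
# (request 2 is inactive this step), request 1 is an encoder.
print(padding_seed_reference([100, 200, 300, 400], [3, 2, 0, 2], [0, 5, 0, 0]))
# -> [100, 104, 108, 200, 400, 404]

The stride of 4 matches the existing per-step infer_seed increment of 4, which the model-runner change below scales up when speculative decoding is enabled.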
@@ -501,7 +529,7 @@ class Sampler(nn.Layer):
             sampling_metadata.top_p,
             sampling_metadata.top_k,
             sampling_metadata.top_k_list,
-            seed=sampling_metadata.seed[0, 0],
+            topp_seed=sampling_metadata.seed,
         )

         logprobs_tensors = (
@@ -725,13 +753,14 @@ class SpeculativeSampler(nn.Layer):
         probs = F.softmax(logits)

-        top_p, top_k = padding_sampling_params(
+        top_p, top_k, topp_seed = padding_sampling_params(
             sampling_metadata.top_p,
             sampling_metadata.top_k,
+            sampling_metadata.seed,
             share_inputs["seq_lens_this_time"],
             share_inputs["seq_lens_encoder"],
         )
-        _, sampled_token_ids = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, seed=sampling_metadata.seed[0, 0])
+        _, sampled_token_ids = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, topp_seed=topp_seed)

         verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
             probs,
@@ -1064,13 +1093,7 @@ class MTPSampler(nn.Layer):
         )
         probs = F.softmax(logits)

-        top_p, top_k = padding_sampling_params(
-            sampling_metadata.top_p,
-            sampling_metadata.top_k,
-            share_inputs["seq_lens_this_time"],
-            share_inputs["seq_lens_encoder"],
-        )
-        _, next_tokens = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, seed=sampling_metadata.seed[0, 0])
+        next_tokens = paddle.argmax(probs, axis=-1)

         token_ids = None
         logprobs_tensors = None
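
The MTP draft head now picks its next tokens greedily instead of calling top_k_top_p_sampling. Since softmax is strictly monotonic along the vocabulary axis, argmax over the softmaxed probs selects the same token as argmax over the raw logits; a tiny self-contained check, illustrative and not part of the commit:

import paddle
import paddle.nn.functional as F

# Greedy selection is invariant under softmax: argmax(softmax(x)) == argmax(x).
logits = paddle.randn([4, 1024])
probs = F.softmax(logits)
assert paddle.equal_all(paddle.argmax(probs, axis=-1), paddle.argmax(logits, axis=-1))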

Changed file 2 of 3: GPU model runner (GPUModelRunner)

@@ -190,9 +190,13 @@ class GPUModelRunner(ModelRunnerBase):
         # Initialize share inputs
         self._init_share_inputs(self.scheduler_config.max_num_seqs)
+        increment_value = (
+            4 if not self.speculative_decoding else (self.speculative_config.num_speculative_tokens + 1) * 4
+        )
         self.infer_seed_increment = paddle.full(
             shape=[self.scheduler_config.max_num_seqs, 1],
-            fill_value=4,
+            fill_value=increment_value,
             dtype="int64",
         ).cpu()
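
A note on the increment: padding_sampling_params spaces per-token seeds 4 apart within a step, and a speculative step covers up to num_speculative_tokens + 1 tokens per request, so bumping infer_seed by (num_speculative_tokens + 1) * 4 per step keeps consecutive steps from reusing per-token seeds. A small arithmetic sketch, illustrative only; the standalone helper below is not FastDeploy API:

def per_step_seed_increment(speculative_decoding, num_speculative_tokens):
    # Mirrors the increment_value expression above: 4 for regular decoding,
    # (num_speculative_tokens + 1) * 4 when speculative decoding is enabled.
    return 4 if not speculative_decoding else (num_speculative_tokens + 1) * 4

print(per_step_seed_increment(False, 0))  # 4: one token per step, seeds 4 apart across steps
print(per_step_seed_increment(True, 3))   # 16: up to 4 tokens per step use base, base+4, base+8, base+12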

Changed file 3 of 3: sampler unit tests

@@ -30,6 +30,7 @@ from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
 from fastdeploy.model_executor.layers.sample.sampler import (
     MTPSampler,
     SpeculativeSampler,
+    padding_sampling_params,
 )
@@ -72,7 +73,7 @@ def _create_default_sampling_metadata(
         bad_words_token_ids=paddle.full(shape=[batch_size], fill_value=-1, dtype="int64"),
         eos_token_ids=paddle.full(shape=[batch_size], fill_value=-2, dtype="int64"),
         min_p=paddle.randn([batch_size]),
-        seed=paddle.to_tensor([[2025]]),
+        seed=paddle.full(shape=[batch_size], fill_value=0, dtype="int64"),
     )
     if max_num_logprobs is not None:
         fake_sampling_metadata.max_num_logprobs = max_num_logprobs
@@ -143,6 +144,19 @@ def _create_share_inputs(max_num_seqs, max_draft_token_num, max_model_len, vocab
     return share_inputs


+def _create_padding_inputs():
+    # 4 requests; request 2 is a decoder with seq_lens_this_time == 0, so only 3 are active
+    top_p = paddle.to_tensor([[0.9], [0.8], [0.7], [1.0]], dtype="float32")
+    top_k = paddle.to_tensor([[10], [20], [30], [40]], dtype="int32")
+    infer_seed = paddle.to_tensor([[100], [200], [300], [400]], dtype="int64")
+    # decoder, encoder, inactive decoder, decoder
+    seq_lens_encoder = paddle.to_tensor([[0], [5], [0], [0]], dtype="int32")
+    seq_lens_this_time = paddle.to_tensor([[3], [2], [0], [2]], dtype="int32")
+    return top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder
+
+
 def test_speculative_sampler():
     batch_size = 32
     vocab_size = 1024
@@ -220,8 +234,52 @@ def test_mtp_sampler_logprobs():
     sampler(logits, sampling_metadata, max_model_len, share_inputs)


+def test_padding_sampling_params_basic():
+    top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder = _create_padding_inputs()
+    top_p_pad, top_k_pad, seed_pad = padding_sampling_params(
+        top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder
+    )
+
+    # decoder(3) + encoder(1) + decoder(2) = 6 padded tokens
+    assert top_p_pad.shape == [6, 1]
+    assert top_k_pad.shape == [6, 1]
+    assert seed_pad.shape == [6, 1]
+
+    # top_p padding check
+    expected_top_p = [0.9, 0.9, 0.9, 0.8, 1.0, 1.0]
+    assert paddle.allclose(top_p_pad.squeeze(), paddle.to_tensor(expected_top_p, dtype="float32"))
+
+    # top_k padding check
+    expected_top_k = [10, 10, 10, 20, 40, 40]
+    assert paddle.equal_all(top_k_pad.squeeze(), paddle.to_tensor(expected_top_k, dtype="int32"))
+
+
+def test_padding_sampling_params_seed_offset():
+    top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder = _create_padding_inputs()
+    _, _, seed_pad = padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder)
+
+    # decoder(0): 100 + 4*k
+    # encoder(1): 200 (no offset)
+    # decoder(2): inactive, contributes no tokens
+    # decoder(3): 400 + 4*k
+    expected_seed = [
+        100,
+        104,
+        108,  # first decoder seq (len=3)
+        200,  # encoder
+        400,
+        404,  # second decoder seq (len=2)
+    ]
+    assert paddle.equal_all(seed_pad.squeeze(), paddle.to_tensor(expected_seed, dtype="int64"))
+
+
 if __name__ == "__main__":
     test_speculative_sampler()
     test_speculative_sampler_logprobs()
     test_mtp_sampler()
     test_mtp_sampler_logprobs()
+    test_padding_sampling_params_basic()
+    test_padding_sampling_params_seed_offset()