From a7359d1c1d3bea9fd56fce0804f982eeac4dbd84 Mon Sep 17 00:00:00 2001
From: freeliuzc
Date: Wed, 17 Dec 2025 16:53:47 +0800
Subject: [PATCH] [Cherry-Pick][CI]Support different inferseed in speculate decoding(#5568) (#5597)

* fix mtp entropy drop in RL

* optimize usage and fix unit test

* optimize padding_sampling_params speed(vectorized)
---
 .../model_executor/layers/sample/sampler.py | 47 +++++++++++----
 fastdeploy/worker/gpu_model_runner.py       |  6 +-
 tests/layers/test_speculative_sampler.py    | 60 ++++++++++++++++++-
 3 files changed, 99 insertions(+), 14 deletions(-)

diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py
index 52bb358bf..28687ea53 100644
--- a/fastdeploy/model_executor/layers/sample/sampler.py
+++ b/fastdeploy/model_executor/layers/sample/sampler.py
@@ -56,12 +56,40 @@ def top_p_normalize_probs_paddle(
     return paddle.zeros_like(probs_sort).put_along_axis_(indices=probs_idx, values=probs_sort, axis=-1)
 
 
-def padding_sampling_params(top_p, top_k, seq_lens_this_time, seq_lens_encoder):
+def padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder):
     real_bsz = seq_lens_this_time.shape[0]
     repeats = paddle.where(seq_lens_encoder[:real_bsz] == 0, seq_lens_this_time, paddle.ones_like(seq_lens_this_time))
     top_p_padding = paddle.repeat_interleave(top_p[:real_bsz], repeats).unsqueeze(1)
     top_k_padding = paddle.repeat_interleave(top_k[:real_bsz], repeats).unsqueeze(1)
-    return top_p_padding, top_k_padding
+    topp_seed = paddle.repeat_interleave(infer_seed[:real_bsz], repeats).unsqueeze(1)
+
+    MAX_INFER_SEED = 9223372036854775806
+
+    token_lens = paddle.where(
+        seq_lens_encoder[:real_bsz] == 0,
+        seq_lens_this_time,
+        paddle.ones_like(seq_lens_this_time),
+    )
+
+    batch_start = (paddle.cumsum(token_lens, axis=0) - token_lens.astype("int64")).reshape(-1)  # [B]
+    token_batch_ids = paddle.repeat_interleave(
+        paddle.arange(token_lens.shape[0], dtype="int64"),
+        token_lens,
+    )
+    token_pos = paddle.arange(topp_seed.shape[0], dtype="int64")
+    local_pos = token_pos - paddle.gather(batch_start, token_batch_ids)
+
+    is_decoder = paddle.gather(seq_lens_encoder[:real_bsz] == 0, token_batch_ids).reshape(-1)
+
+    offsets = paddle.where(
+        is_decoder,
+        local_pos * 4,
+        paddle.zeros_like(local_pos),
+    )
+
+    topp_seed[:, 0] = (topp_seed[:, 0] + offsets) % MAX_INFER_SEED
+
+    return top_p_padding, top_k_padding, topp_seed
 
 
 class GuidedDecoding:
@@ -501,7 +529,7 @@ class Sampler(nn.Layer):
             sampling_metadata.top_p,
             sampling_metadata.top_k,
             sampling_metadata.top_k_list,
-            seed=sampling_metadata.seed[0, 0],
+            topp_seed=sampling_metadata.seed,
         )
 
         logprobs_tensors = (
@@ -725,13 +753,14 @@ class SpeculativeSampler(nn.Layer):
 
         probs = F.softmax(logits)
 
-        top_p, top_k = padding_sampling_params(
+        top_p, top_k, topp_seed = padding_sampling_params(
             sampling_metadata.top_p,
             sampling_metadata.top_k,
+            sampling_metadata.seed,
             share_inputs["seq_lens_this_time"],
             share_inputs["seq_lens_encoder"],
         )
-        _, sampled_token_ids = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, seed=sampling_metadata.seed[0, 0])
+        _, sampled_token_ids = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, topp_seed=topp_seed)
 
         verify_scores, verify_tokens, actual_candidate_len = top_p_candidates(
             probs,
@@ -1064,13 +1093,7 @@ class MTPSampler(nn.Layer):
         )
         probs = F.softmax(logits)
 
-        top_p, top_k = padding_sampling_params(
-            sampling_metadata.top_p,
-            sampling_metadata.top_k,
-            share_inputs["seq_lens_this_time"],
-            share_inputs["seq_lens_encoder"],
-        )
-        _, next_tokens = top_k_top_p_sampling(probs, top_p=top_p, top_k=top_k, seed=sampling_metadata.seed[0, 0])
+        next_tokens = paddle.argmax(probs, axis=-1)
 
         token_ids = None
         logprobs_tensors = None
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index b3fcdab44..418fef909 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -190,9 +190,13 @@ class GPUModelRunner(ModelRunnerBase):
         # Initialize share inputs
         self._init_share_inputs(self.scheduler_config.max_num_seqs)
 
+        increment_value = (
+            4 if not self.speculative_decoding else (self.speculative_config.num_speculative_tokens + 1) * 4
+        )
+
         self.infer_seed_increment = paddle.full(
             shape=[self.scheduler_config.max_num_seqs, 1],
-            fill_value=4,
+            fill_value=increment_value,
             dtype="int64",
         ).cpu()
 
diff --git a/tests/layers/test_speculative_sampler.py b/tests/layers/test_speculative_sampler.py
index 32f95bfd9..e14503071 100644
--- a/tests/layers/test_speculative_sampler.py
+++ b/tests/layers/test_speculative_sampler.py
@@ -30,6 +30,7 @@ from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
 from fastdeploy.model_executor.layers.sample.sampler import (
     MTPSampler,
     SpeculativeSampler,
+    padding_sampling_params,
 )
 
 
@@ -72,7 +73,7 @@ def _create_default_sampling_metadata(
         bad_words_token_ids=paddle.full(shape=[batch_size], fill_value=-1, dtype="int64"),
         eos_token_ids=paddle.full(shape=[batch_size], fill_value=-2, dtype="int64"),
         min_p=paddle.randn([batch_size]),
-        seed=paddle.to_tensor([[2025]]),
+        seed=paddle.full(shape=[batch_size], fill_value=0, dtype="int64"),
     )
     if max_num_logprobs is not None:
         fake_sampling_metadata.max_num_logprobs = max_num_logprobs
@@ -143,6 +144,19 @@ def _create_share_inputs(max_num_seqs, max_draft_token_num, max_model_len, vocab
     return share_inputs
 
 
+def _create_padding_inputs():
+    # batch_size = 3
+    top_p = paddle.to_tensor([[0.9], [0.8], [0.7], [1.0]], dtype="float32")
+    top_k = paddle.to_tensor([[10], [20], [30], [40]], dtype="int32")
+    infer_seed = paddle.to_tensor([[100], [200], [300], [400]], dtype="int64")
+
+    # decoder, encoder, decoder
+    seq_lens_encoder = paddle.to_tensor([[0], [5], [0], [0]], dtype="int32")
+    seq_lens_this_time = paddle.to_tensor([[3], [2], [0], [2]], dtype="int32")
+
+    return top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder
+
+
 def test_speculative_sampler():
     batch_size = 32
     vocab_size = 1024
@@ -220,8 +234,52 @@ def test_mtp_sampler_logprobs():
     sampler(logits, sampling_metadata, max_model_len, share_inputs)
 
 
+def test_padding_sampling_params_basic():
+    top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder = _create_padding_inputs()
+
+    top_p_pad, top_k_pad, seed_pad = padding_sampling_params(
+        top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder
+    )
+
+    # decoder(3) + encoder(1) + decoder(2) = 6
+    assert top_p_pad.shape == [6, 1]
+    assert top_k_pad.shape == [6, 1]
+    assert seed_pad.shape == [6, 1]
+
+    # top_p padding check
+    expected_top_p = [0.9, 0.9, 0.9, 0.8, 1.0, 1.0]
+    assert paddle.allclose(top_p_pad.squeeze(), paddle.to_tensor(expected_top_p, dtype="float32"))
+
+    # top_k padding check
+    expected_top_k = [10, 10, 10, 20, 40, 40]
+    assert paddle.equal_all(top_k_pad.squeeze(), paddle.to_tensor(expected_top_k, dtype="int32"))
+
+
+def test_padding_sampling_params_seed_offset():
+    top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder = _create_padding_inputs()
+
+    _, _, seed_pad = padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder)
+
+    # decoder(0): 100 + 4*k
+    # encoder(1): 200 (no offset)
+    # null
+    # decoder(3): 400 + 4*k
+    expected_seed = [
+        100,
+        104,
+        108,  # first decoder seq (len=3)
+        200,  # encoder
+        400,
+        404,  # second decoder seq (len=2)
+    ]
+
+    assert paddle.equal_all(seed_pad.squeeze(), paddle.to_tensor(expected_seed, dtype="int64"))
+
+
 if __name__ == "__main__":
     test_speculative_sampler()
     test_speculative_sampler_logprobs()
     test_mtp_sampler()
     test_mtp_sampler_logprobs()
+    test_padding_sampling_params_basic()
+    test_padding_sampling_params_seed_offset()
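
Note on the seed layout, as a sketch rather than part of the patch itself: within one speculative step, padding_sampling_params now gives every decode-phase row its own seed by adding a stride-4 offset per token position to the request's infer_seed, while prefill (encoder) rows keep their seed unchanged; gpu_model_runner.py in turn advances infer_seed by (num_speculative_tokens + 1) * 4 per step when speculative decoding is enabled, which appears to keep one step's per-token offsets from overlapping the next step's base seeds. The snippet below re-implements that layout in plain NumPy purely for illustration: padding_sampling_params_np is a hypothetical name, the NumPy port and the 1-D input shapes are a simplification, and the inputs and expected seeds mirror _create_padding_inputs and test_padding_sampling_params_seed_offset above.

import numpy as np

MAX_INFER_SEED = 9223372036854775806  # same wrap-around bound used in sampler.py


def padding_sampling_params_np(top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder):
    # Decode-phase requests contribute one row per token this step; prefill requests contribute one row.
    is_decoder = seq_lens_encoder == 0
    repeats = np.where(is_decoder, seq_lens_this_time, 1)

    top_p_pad = np.repeat(top_p, repeats)[:, None]
    top_k_pad = np.repeat(top_k, repeats)[:, None]
    seed_pad = np.repeat(infer_seed, repeats).astype(np.int64)

    # Vectorized "position of each row inside its own request", mirroring the cumsum/gather trick above.
    batch_start = np.cumsum(repeats) - repeats
    token_batch_ids = np.repeat(np.arange(len(repeats)), repeats)
    local_pos = np.arange(int(repeats.sum())) - batch_start[token_batch_ids]

    # Decode rows get a distinct seed per token (stride 4); prefill rows keep theirs.
    offsets = np.where(np.repeat(is_decoder, repeats), local_pos * 4, 0)
    seed_pad = ((seed_pad + offsets) % MAX_INFER_SEED)[:, None]
    return top_p_pad, top_k_pad, seed_pad


# Inputs mirror _create_padding_inputs, flattened to 1-D for NumPy.
top_p = np.array([0.9, 0.8, 0.7, 1.0], dtype=np.float32)
top_k = np.array([10, 20, 30, 40], dtype=np.int32)
infer_seed = np.array([100, 200, 300, 400], dtype=np.int64)
seq_lens_this_time = np.array([3, 2, 0, 2], dtype=np.int32)
seq_lens_encoder = np.array([0, 5, 0, 0], dtype=np.int32)

_, _, seed_pad = padding_sampling_params_np(top_p, top_k, infer_seed, seq_lens_this_time, seq_lens_encoder)
assert seed_pad[:, 0].tolist() == [100, 104, 108, 200, 400, 404]

Running the snippet should finish silently; the final assert encodes the same [100, 104, 108, 200, 400, 404] layout that test_padding_sampling_params_seed_offset checks with Paddle tensors.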