diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py
index c60ec8360..f88ce099f 100644
--- a/fastdeploy/model_executor/layers/sample/sampler.py
+++ b/fastdeploy/model_executor/layers/sample/sampler.py
@@ -480,7 +480,7 @@ class SpeculativeSampler(nn.Layer):
         share_inputs = sampling_metadata.share_inputs
         last_logits = logits
         real_bsz = share_inputs["seq_lens_this_time"].shape[0]
-        batch_token_num = share_inputs["batch_token_num"][:real_bsz]
+        batch_token_num = share_inputs["accept_num"][:real_bsz]
         temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
         top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
 
@@ -637,7 +637,7 @@ class SpeculativeSampler(nn.Layer):
             batch_token_num = paddle.where(
                 share_inputs["seq_lens_encoder"][:real_bsz] != 0,
                 paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]),
-                share_inputs["accept_num"][:real_bsz].unsqueeze(1),
+                share_inputs["seq_lens_this_time"],
             ).squeeze(1)
             share_inputs["batch_token_num"] = batch_token_num
             ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
@@ -647,11 +647,11 @@
                 [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])]
             ).astype("int32")
             share_inputs["cu_batch_token_offset"] = cu_batch_token_offset
-            target_logtis = paddle.empty(
+            target_logits = paddle.empty(
                 [share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]], dtype=logits.dtype
             )
             speculate_get_target_logits(
-                target_logtis,
+                target_logits,
                 logits,
                 cu_batch_token_offset,
                 ori_cu_batch_token_offset,
@@ -660,25 +660,22 @@
                 share_inputs["accept_num"],
             )
             if self.logprobs_mode == "raw_logprobs":
-                raw_logprobs = self.compute_logprobs(target_logtis, sampling_metadata)
+                raw_logprobs = self.compute_logprobs(target_logits, sampling_metadata)
             elif self.logprobs_mode == "raw_logits":
-                raw_logprobs = target_logtis.clone()
+                raw_logprobs = target_logits.clone()
 
         logprobs_tensors = None
         token_ids = share_inputs["accept_tokens"]
         if num_logprobs is not None:
             token_ids = paddle.concat(
-                [
-                    share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]]
-                    for i in range(share_inputs["accept_num"][:real_bsz].shape[0])
-                ]
+                [share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]] for i in range(real_bsz)]
             )
             logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
 
         sampler_output = SamplerOutput(
             sampled_token_ids=token_ids,
             logprobs_tensors=logprobs_tensors,
-            token_num_per_batch=batch_token_num,
+            token_num_per_batch=share_inputs["accept_num"],
             cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
         )
 
diff --git a/tests/layers/test_speculative_sampler.py b/tests/layers/test_speculative_sampler.py
new file mode 100644
index 000000000..c5632d62e
--- /dev/null
+++ b/tests/layers/test_speculative_sampler.py
@@ -0,0 +1,226 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+from unittest.mock import Mock
+
+import paddle
+
+from fastdeploy.config import (
+    CacheConfig,
+    FDConfig,
+    GraphOptimizationConfig,
+    ParallelConfig,
+    SchedulerConfig,
+    SpeculativeConfig,
+)
+from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
+from fastdeploy.model_executor.layers.sample.sampler import (
+    MTPSampler,
+    SpeculativeSampler,
+)
+
+
+def _create_fake_logits(batch_size: int, vocab_size: int) -> paddle.Tensor:
+    fake_logits = paddle.rand(shape=[batch_size, vocab_size], dtype="float32")
+    return fake_logits
+
+
+def _create_penalty_tensor(batch_size: int, penalty_value: float) -> paddle.Tensor:
+    return paddle.full(shape=[batch_size, 1], fill_value=penalty_value, dtype="float32")
+
+
+def _create_tokens_tensor(
+    batch_size: int,
+    max_seq_len: int,
+) -> paddle.Tensor:
+    pre_token_ids = paddle.full(shape=[batch_size, max_seq_len], fill_value=-1, dtype="int64")
+    return pre_token_ids
+
+
+def _create_default_sampling_metadata(
+    batch_size: int,
+    min_seq_len: int,
+    max_seq_len: int,
+    max_num_logprobs: int = None,
+) -> SamplingMetadata:
+
+    fake_sampling_metadata = SamplingMetadata(
+        temperature=paddle.full(shape=[batch_size, 1], fill_value=0.9, dtype="float32"),
+        top_p=paddle.full(shape=[batch_size, 1], fill_value=0.7, dtype="float32"),
+        prompt_ids=paddle.full(shape=[batch_size, max_seq_len], fill_value=0, dtype="int64"),
+        prompt_lens=paddle.full(shape=[batch_size, 1], fill_value=5, dtype="int64"),
+        step_idx=paddle.full(shape=[batch_size, 1], fill_value=0, dtype="int64"),
+        pre_token_ids=_create_tokens_tensor(batch_size, max_seq_len),
+        frequency_penalties=_create_penalty_tensor(batch_size, 0.0),
+        presence_penalties=_create_penalty_tensor(batch_size, 0.0),
+        repetition_penalties=_create_penalty_tensor(batch_size, 1.0),
+        min_dec_lens=paddle.full(shape=[batch_size, 1], fill_value=min_seq_len, dtype="int64"),
+        bad_words_token_ids=paddle.full(shape=[batch_size], fill_value=-1, dtype="int64"),
+        eos_token_ids=paddle.full(shape=[batch_size], fill_value=-2, dtype="int64"),
+        min_p=paddle.randn([batch_size]),
+        seed=paddle.to_tensor([[2025]]),
+    )
+    if max_num_logprobs is not None:
+        fake_sampling_metadata.max_num_logprobs = max_num_logprobs
+    return fake_sampling_metadata
+
+
+def _create_fd_config(max_model_len):
+    model_config: Mock = Mock()
+    model_config.max_model_len = max_model_len
+    speculative_config = SpeculativeConfig({})
+    graph_opt_config = GraphOptimizationConfig({})
+    scheduler_config = SchedulerConfig({})
+    parallel_config = ParallelConfig({})
+    cache_config = CacheConfig({})
+    cache_config.cache_transfer_protocol = "rdma,ipc"
+    cache_config.pd_comm_port = "2334"
+    fd_config = FDConfig(
+        model_config=model_config,
+        speculative_config=speculative_config,
+        graph_opt_config=graph_opt_config,
+        scheduler_config=scheduler_config,
+        parallel_config=parallel_config,
+        cache_config=cache_config,
+    )
+
+    return fd_config
+
+
+def _create_share_inputs(max_num_seqs, max_draft_token_num, max_model_len, vocab_size):
+    share_inputs = {}
+    share_inputs["seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 2, dtype="int32")
+    share_inputs["output_cum_offsets"] = paddle.concat(
+        [(max_model_len - share_inputs["seq_lens_this_time"][i]) * i for i in range(max_num_seqs)]
+    )
+    share_inputs["output_padding_offset"] = paddle.repeat_interleave(share_inputs["output_cum_offsets"], 2)
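+    # Note (comment added for clarity): every request in this fixture is a
+    # decode step that scores max_draft_token_num + 1 candidate tokens, so the
+    # accept/draft buffers below are sized [max_num_seqs, max_draft_token_num + 1].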
+    share_inputs["accept_tokens"] = paddle.full(
+        shape=[max_num_seqs, max_draft_token_num + 1], fill_value=0, dtype="int64"
+    )
+    share_inputs["accept_num"] = paddle.full(shape=[max_num_seqs], fill_value=1, dtype="int32")
+    share_inputs["step_idx"] = paddle.full([max_num_seqs, 1], 1, dtype="int64")
+    share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], False, dtype="bool")
+    share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
+    share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 2, dtype="int32")
+    share_inputs["draft_tokens"] = paddle.full(
+        shape=[max_num_seqs, max_draft_token_num + 1], fill_value=0, dtype="int64"
+    )
+    share_inputs["max_dec_len"] = paddle.full([max_num_seqs, 1], max_model_len, dtype="int64")
+    share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
+    share_inputs["actual_draft_token_num"] = paddle.full(
+        shape=[max_num_seqs], fill_value=max_draft_token_num, dtype="int32"
+    )
+
+    share_inputs["batch_token_num"] = paddle.where(
+        share_inputs["seq_lens_encoder"] != 0,
+        paddle.ones_like(share_inputs["seq_lens_encoder"]),
+        share_inputs["seq_lens_this_time"],
+    ).squeeze(1)
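+    # batch_token_num mirrors the sampler's bookkeeping (comment added for
+    # clarity): a prefill request contributes exactly one logits row, while a
+    # decode request contributes seq_lens_this_time rows (drafts plus bonus).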
+    share_inputs["next_token_num"] = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32")
+    share_inputs["cu_batch_token_offset"] = paddle.concat(
+        [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"])]
+    ).astype("int32")
+    share_inputs["cu_next_token_offset"] = paddle.full(shape=[max_num_seqs + 1], fill_value=0, dtype="int32")
+    share_inputs["substep"] = 0
+    share_inputs["draft_logits"] = paddle.full(
+        [max_num_seqs * (max_draft_token_num + 1), vocab_size], -1, dtype="float32"
+    )
+
+    return share_inputs
+
+
+def test_speculative_sampler():
+    batch_size = 32
+    vocab_size = 1024
+    min_seq_len = 1
+    max_seq_len = 1024
+    max_model_len = 1024
+    max_draft_token_num = 1
+
+    fd_config = _create_fd_config(max_model_len)
+    sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len)
+    logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size)
+    share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)
+
+    sampler = SpeculativeSampler(fd_config)
+    sampler(logits, sampling_metadata, max_model_len, share_inputs)
+
+
+def test_speculative_sampler_logprobs():
+    batch_size = 32
+    vocab_size = 1024
+    min_seq_len = 1
+    max_seq_len = 1024
+    max_model_len = 1024
+    max_draft_token_num = 1
+
+    fd_config = _create_fd_config(max_model_len)
+    share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)
+    sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len, max_num_logprobs=0)
+    sampling_metadata.share_inputs = share_inputs
+    logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size)
+
+    logprobs_mode_list = ["raw_logprobs", "raw_logits"]
+    for logprobs_mode in logprobs_mode_list:
+        fd_config.model_config.logprobs_mode = logprobs_mode
+        sampler = SpeculativeSampler(fd_config)
+        sampler(logits, sampling_metadata, max_model_len, share_inputs)
+
+
+def test_mtp_sampler():
+    batch_size = 32
+    vocab_size = 1024
+    min_seq_len = 1
+    max_seq_len = 1024
+    max_model_len = 1024
+    max_draft_token_num = 1
+
+    fd_config = _create_fd_config(max_model_len)
+    sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len)
+    logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size)
+
+    share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)
+
+    sampler = MTPSampler(fd_config)
+    sampler(logits, sampling_metadata, max_model_len, share_inputs)
+
+
+def test_mtp_sampler_logprobs():
+    batch_size = 32
+    vocab_size = 1024
+    min_seq_len = 1
+    max_seq_len = 1024
+    max_model_len = 1024
+    max_draft_token_num = 1
+
+    fd_config = _create_fd_config(max_model_len)
+    share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)
+    sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len, max_num_logprobs=0)
+    sampling_metadata.share_inputs = share_inputs
+    logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size)
+
+    logprobs_mode_list = ["raw_logprobs", "raw_logits"]
+    for logprobs_mode in logprobs_mode_list:
+        fd_config.model_config.logprobs_mode = logprobs_mode
+        sampler = MTPSampler(fd_config)
+        sampler(logits, sampling_metadata, max_model_len, share_inputs)
+
+
+if __name__ == "__main__":
+    test_speculative_sampler()
+    test_speculative_sampler_logprobs()
+    test_mtp_sampler()
+    test_mtp_sampler_logprobs()
diff --git a/tests/operators/test_speculate_get_target_logits.py b/tests/operators/test_speculate_get_target_logits.py
new file mode 100644
index 000000000..5d930418a
--- /dev/null
+++ b/tests/operators/test_speculate_get_target_logits.py
@@ -0,0 +1,144 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import paddle
+
+from fastdeploy.model_executor.layers.sample.ops.speculate_logprob_utils import (
+    speculate_get_target_logits,
+)
+
+
+class TestSpeculateGetTargetLogits(unittest.TestCase):
+
+    def setUp(self):
+        self.vocab_size = 8192
+
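+    # Fixture convention (comment added for clarity): logits row i is filled
+    # with the constant i, so the expected output is checkable by row value.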
+    def test_all_decode(self):
+        token_num = 6
+        logits = paddle.full(shape=[token_num, self.vocab_size], fill_value=-1, dtype="float32")
+        for i in range(token_num):
+            logits[i][:] = i
+
+        seq_lens_encoder = paddle.to_tensor([[0], [0], [0]], dtype="int32")
+        seq_lens_this_time = paddle.to_tensor([[2], [2], [2]], dtype="int32")
+        accept_num = paddle.to_tensor([1, 2, 1], dtype="int32")
+        batch_token_num = paddle.where(
+            seq_lens_encoder != 0,
+            paddle.ones_like(seq_lens_encoder),
+            seq_lens_this_time,
+        ).squeeze(1)
+        ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
+            "int32"
+        )
+        cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(accept_num)]).astype("int32")
+        target_logits = paddle.empty([accept_num.sum(), logits.shape[1]], dtype=logits.dtype)
+
+        speculate_get_target_logits(
+            target_logits,
+            logits,
+            cu_batch_token_offset,
+            ori_cu_batch_token_offset,
+            seq_lens_this_time,
+            seq_lens_encoder,
+            accept_num,
+        )
+
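+        # The op keeps the first accept_num rows of each batch's logits
+        # segment: batch 0 -> row 0, batch 1 -> rows 2-3, batch 2 -> row 4.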
+        gold_logits = paddle.full(shape=[4, self.vocab_size], fill_value=-1, dtype="float32")
+        gold_logits[0][:] = 0
+        gold_logits[1][:] = 2
+        gold_logits[2][:] = 3
+        gold_logits[3][:] = 4
+
+        assert paddle.allclose(target_logits, gold_logits)
+
+    def test_partial_decode(self):
+        token_num = 5
+        logits = paddle.full(shape=[token_num, self.vocab_size], fill_value=-1, dtype="float32")
+        for i in range(token_num):
+            logits[i][:] = i
+
+        seq_lens_encoder = paddle.to_tensor([[10], [0], [0]], dtype="int32")
+        seq_lens_this_time = paddle.to_tensor([[10], [2], [2]], dtype="int32")
+        accept_num = paddle.to_tensor([1, 2, 1], dtype="int32")
+        batch_token_num = paddle.where(
+            seq_lens_encoder != 0,
+            paddle.ones_like(seq_lens_encoder),
+            seq_lens_this_time,
+        ).squeeze(1)
+        ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
+            "int32"
+        )
+        cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(accept_num)]).astype("int32")
+        target_logits = paddle.empty([accept_num.sum(), logits.shape[1]], dtype=logits.dtype)
+
+        speculate_get_target_logits(
+            target_logits,
+            logits,
+            cu_batch_token_offset,
+            ori_cu_batch_token_offset,
+            seq_lens_this_time,
+            seq_lens_encoder,
+            accept_num,
+        )
+
+        gold_logits = paddle.full(shape=[4, self.vocab_size], fill_value=-1, dtype="float32")
+        gold_logits[0][:] = 0
+        gold_logits[1][:] = 1
+        gold_logits[2][:] = 2
+        gold_logits[3][:] = 3
+
+        assert paddle.allclose(target_logits, gold_logits)
+
+    def test_all_prefill(self):
+        token_num = 3
+        logits = paddle.full(shape=[token_num, self.vocab_size], fill_value=-1, dtype="float32")
+        for i in range(token_num):
+            logits[i][:] = i
+
+        seq_lens_encoder = paddle.to_tensor([[10], [10], [10]], dtype="int32")
+        seq_lens_this_time = paddle.to_tensor([[10], [10], [10]], dtype="int32")
+        accept_num = paddle.to_tensor([1, 1, 1], dtype="int32")
+        batch_token_num = paddle.where(
+            seq_lens_encoder != 0,
+            paddle.ones_like(seq_lens_encoder),
+            seq_lens_this_time,
+        ).squeeze(1)
+        ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
+            "int32"
+        )
+        cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(accept_num)]).astype("int32")
+        target_logits = paddle.empty([accept_num.sum(), logits.shape[1]], dtype=logits.dtype)
+
+        speculate_get_target_logits(
+            target_logits,
+            logits,
+            cu_batch_token_offset,
+            ori_cu_batch_token_offset,
+            seq_lens_this_time,
+            seq_lens_encoder,
+            accept_num,
+        )
+
+        gold_logits = paddle.full(shape=[3, self.vocab_size], fill_value=-1, dtype="float32")
+        gold_logits[0][:] = 0
+        gold_logits[1][:] = 1
+        gold_logits[2][:] = 2
+
+        assert paddle.allclose(target_logits, gold_logits)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/operators/test_speculate_insert_first_token.py b/tests/operators/test_speculate_insert_first_token.py
new file mode 100644
index 000000000..1e65f2b08
--- /dev/null
+++ b/tests/operators/test_speculate_insert_first_token.py
@@ -0,0 +1,100 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import paddle
+
+from fastdeploy.model_executor.layers.sample.ops.speculate_logprob_utils import (
+    speculate_insert_first_token,
+)
+
+
+class TestSpeculateInsertFirstToken(unittest.TestCase):
+
+    def test_all_decode(self):
+        token_num = 5
+        accept_tokens = paddle.to_tensor([[1001, 1002], [1003, 1004], [1005, 1006]], dtype="int64")
+        next_tokens = paddle.to_tensor([[2001], [2002], [2003], [2004], [2005]], dtype="int64")
+        cu_next_token_offset = paddle.to_tensor([0, 2, 3, 5], dtype="int32")
+        cu_batch_token_offset = paddle.to_tensor([0, 2, 3, 5], dtype="int32")
+        seq_lens_this_time = paddle.to_tensor([[2], [1], [2]], dtype="int32")
+        seq_lens_encoder = paddle.to_tensor([[0], [0], [0]], dtype="int32")
+
+        token_id = paddle.empty(token_num, dtype="int64")
+        speculate_insert_first_token(
+            token_id,
+            accept_tokens,
+            next_tokens,
+            cu_next_token_offset,
+            cu_batch_token_offset,
+            seq_lens_this_time,
+            seq_lens_encoder,
+        )
+
+        gold_token_id = paddle.to_tensor([2001, 2002, 2003, 2004, 2005], dtype="int64")
+
+        assert paddle.equal_all(token_id, gold_token_id)
+
+    def test_partial_decode(self):
+        token_num = 6
+        accept_tokens = paddle.to_tensor([[1001, 1002], [1003, 1004], [1005, 1006]], dtype="int64")
+        next_tokens = paddle.to_tensor([[2001], [2002], [2003], [2004], [2005]], dtype="int64")
+        cu_next_token_offset = paddle.to_tensor([0, 2, 3, 5], dtype="int32")
+        cu_batch_token_offset = paddle.to_tensor([0, 2, 4, 6], dtype="int32")
+        seq_lens_this_time = paddle.to_tensor([[2], [10], [2]], dtype="int32")
+        seq_lens_encoder = paddle.to_tensor([[0], [10], [0]], dtype="int32")
+
+        token_id = paddle.empty(token_num, dtype="int64")
+        speculate_insert_first_token(
+            token_id,
+            accept_tokens,
+            next_tokens,
+            cu_next_token_offset,
+            cu_batch_token_offset,
+            seq_lens_this_time,
+            seq_lens_encoder,
+        )
+
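+        # Decode batches pass next_tokens through unchanged; the prefill batch
+        # (index 1) first emits accept_tokens[1][0], the token sampled from
+        # the prompt, and then its next token.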
+        gold_token_id = paddle.to_tensor([2001, 2002, 1003, 2003, 2004, 2005], dtype="int64")
+
+        assert paddle.equal_all(token_id, gold_token_id)
+
+    def test_all_prefill(self):
+        token_num = 6
+        accept_tokens = paddle.to_tensor([[1001, 1002], [1003, 1004], [1005, 1006]], dtype="int64")
+        next_tokens = paddle.to_tensor([[2001], [2002], [2003]], dtype="int64")
+        cu_next_token_offset = paddle.to_tensor([0, 1, 2, 3], dtype="int32")
+        cu_batch_token_offset = paddle.to_tensor([0, 2, 4, 6], dtype="int32")
+        seq_lens_this_time = paddle.to_tensor([[10], [10], [10]], dtype="int32")
+        seq_lens_encoder = paddle.to_tensor([[10], [10], [10]], dtype="int32")
+
+        token_id = paddle.empty(token_num, dtype="int64")
+        speculate_insert_first_token(
+            token_id,
+            accept_tokens,
+            next_tokens,
+            cu_next_token_offset,
+            cu_batch_token_offset,
+            seq_lens_this_time,
+            seq_lens_encoder,
+        )
+
+        gold_token_id = paddle.to_tensor([1001, 2001, 1003, 2002, 1005, 2003], dtype="int64")
+
+        assert paddle.equal_all(token_id, gold_token_id)
+
+
+if __name__ == "__main__":
+    unittest.main()