Add ut for speculative sampler (#4650)

Author: GoldPancake
Date: 2025-10-30 10:37:49 +08:00
Committed by: GitHub
Parent: 1712e1351b
Commit: fddda50cb9

4 changed files with 478 additions and 11 deletions

View File

@@ -480,7 +480,7 @@ class SpeculativeSampler(nn.Layer):
share_inputs = sampling_metadata.share_inputs
last_logits = logits
real_bsz = share_inputs["seq_lens_this_time"].shape[0]
batch_token_num = share_inputs["batch_token_num"][:real_bsz]
batch_token_num = share_inputs["accept_num"][:real_bsz]
temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
@@ -637,7 +637,7 @@ class SpeculativeSampler(nn.Layer):
batch_token_num = paddle.where(
share_inputs["seq_lens_encoder"][:real_bsz] != 0,
paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]),
share_inputs["accept_num"][:real_bsz].unsqueeze(1),
share_inputs["seq_lens_this_time"],
).squeeze(1)
share_inputs["batch_token_num"] = batch_token_num
ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
@@ -647,11 +647,11 @@ class SpeculativeSampler(nn.Layer):
[paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])]
).astype("int32")
share_inputs["cu_batch_token_offset"] = cu_batch_token_offset
-target_logtis = paddle.empty(
+target_logits = paddle.empty(
[share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]], dtype=logits.dtype
)
speculate_get_target_logits(
-target_logtis,
+target_logits,
logits,
cu_batch_token_offset,
ori_cu_batch_token_offset,
@@ -660,25 +660,22 @@ class SpeculativeSampler(nn.Layer):
share_inputs["accept_num"],
)
if self.logprobs_mode == "raw_logprobs":
-raw_logprobs = self.compute_logprobs(target_logtis, sampling_metadata)
+raw_logprobs = self.compute_logprobs(target_logits, sampling_metadata)
elif self.logprobs_mode == "raw_logits":
-raw_logprobs = target_logtis.clone()
+raw_logprobs = target_logits.clone()
logprobs_tensors = None
token_ids = share_inputs["accept_tokens"]
if num_logprobs is not None:
token_ids = paddle.concat(
-[
-share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]]
-for i in range(share_inputs["accept_num"][:real_bsz].shape[0])
-]
+[share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]] for i in range(real_bsz)]
)
logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)
sampler_output = SamplerOutput(
sampled_token_ids=token_ids,
logprobs_tensors=logprobs_tensors,
-token_num_per_batch=batch_token_num,
+token_num_per_batch=share_inputs["accept_num"],
cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
)
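
The hunks above do two things: fix the target_logtis → target_logits typo, and make the logprobs path index per-request outputs by accept_num (the tokens the verifier actually accepted) while batch_token_num counts the tokens scored this step. A minimal sketch of the two offset arrays this bookkeeping produces, with toy values that are not part of the diff:

import paddle

# Toy batch: request 0 is a prefill chunk (one scored token), requests 1-2 are
# decode steps scoring two draft tokens each; the verifier accepted 1, 2 and 1.
seq_lens_encoder = paddle.to_tensor([[10], [0], [0]], dtype="int32")
seq_lens_this_time = paddle.to_tensor([[10], [2], [2]], dtype="int32")
accept_num = paddle.to_tensor([1, 2, 1], dtype="int32")

# Rows each request owns in the raw logits: 1 for prefill, seq_lens_this_time for decode.
batch_token_num = paddle.where(
    seq_lens_encoder != 0,
    paddle.ones_like(seq_lens_encoder),
    seq_lens_this_time,
).squeeze(1)  # -> [1, 2, 2]

# Offsets into the raw logits vs. offsets into the packed accepted-token view.
ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype("int32")  # [0, 1, 3, 5]
cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(accept_num)]).astype("int32")  # [0, 1, 3, 4]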

View File

@@ -0,0 +1,226 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from unittest.mock import Mock
import paddle
from fastdeploy.config import (
CacheConfig,
FDConfig,
GraphOptimizationConfig,
ParallelConfig,
SchedulerConfig,
SpeculativeConfig,
)
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import (
MTPSampler,
SpeculativeSampler,
)

def _create_fake_logits(batch_size: int, vocab_size: int) -> paddle.Tensor:
fake_logits = paddle.rand(shape=[batch_size, vocab_size], dtype="float32")
return fake_logits

def _create_penalty_tensor(batch_size: int, penalty_value: float) -> paddle.Tensor:
return paddle.full(shape=[batch_size, 1], fill_value=penalty_value, dtype="float32")

def _create_tokens_tensor(
batch_size: int,
max_seq_len: int,
) -> paddle.Tensor:
pre_token_ids = paddle.full(shape=[batch_size, max_seq_len], fill_value=-1, dtype="int64")
return pre_token_ids

def _create_default_sampling_metadata(
batch_size: int,
min_seq_len: int,
max_seq_len: int,
max_num_logprobs: int = None,
) -> SamplingMetadata:
fake_sampling_metadata = SamplingMetadata(
temperature=paddle.full(shape=[batch_size, 1], fill_value=0.9, dtype="float32"),
top_p=paddle.full(shape=[batch_size, 1], fill_value=0.7, dtype="float32"),
prompt_ids=paddle.full(shape=[batch_size, max_seq_len], fill_value=0, dtype="int64"),
prompt_lens=paddle.full(shape=[batch_size, 1], fill_value=5, dtype="int64"),
step_idx=paddle.full(shape=[batch_size, 1], fill_value=0, dtype="int64"),
pre_token_ids=_create_tokens_tensor(batch_size, max_seq_len),
frequency_penalties=_create_penalty_tensor(batch_size, 0.0),
presence_penalties=_create_penalty_tensor(batch_size, 0.0),
repetition_penalties=_create_penalty_tensor(batch_size, 1.0),
min_dec_lens=paddle.full(shape=[batch_size, 1], fill_value=min_seq_len, dtype="int64"),
bad_words_token_ids=paddle.full(shape=[batch_size], fill_value=-1, dtype="int64"),
eos_token_ids=paddle.full(shape=[batch_size], fill_value=-2, dtype="int64"),
min_p=paddle.randn([batch_size]),
seed=paddle.to_tensor([[2025]]),
)
if max_num_logprobs is not None:
fake_sampling_metadata.max_num_logprobs = max_num_logprobs
return fake_sampling_metadata

def _create_fd_config(max_model_len):
model_config: Mock = Mock()
model_config.max_model_len = max_model_len
speculative_config = SpeculativeConfig({})
graph_opt_config = GraphOptimizationConfig({})
scheduler_config = SchedulerConfig({})
parallel_config = ParallelConfig({})
cache_config = CacheConfig({})
cache_config.cache_transfer_protocol = "rdma,ipc"
cache_config.pd_comm_port = "2334"
fd_config = FDConfig(
model_config=model_config,
speculative_config=speculative_config,
graph_opt_config=graph_opt_config,
scheduler_config=scheduler_config,
parallel_config=parallel_config,
cache_config=cache_config,
)
return fd_config

def _create_share_inputs(max_num_seqs, max_draft_token_num, max_model_len, vocab_size):
share_inputs = {}
share_inputs["seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 2, dtype="int32")
share_inputs["output_cum_offsets"] = paddle.concat(
[(max_model_len - share_inputs["seq_lens_this_time"][i]) * i for i in range(max_num_seqs)]
)
share_inputs["output_padding_offset"] = paddle.repeat_interleave(share_inputs["output_cum_offsets"], 2)
share_inputs["accept_tokens"] = paddle.full(
shape=[max_num_seqs, max_draft_token_num + 1], fill_value=0, dtype="int64"
)
share_inputs["accept_num"] = paddle.full(shape=[max_num_seqs], fill_value=1, dtype="int32")
share_inputs["step_idx"] = paddle.full([max_num_seqs, 1], 1, dtype="int64")
share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], False, dtype="bool")
share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 2, dtype="int32")
share_inputs["draft_tokens"] = paddle.full(
shape=[max_num_seqs, max_draft_token_num + 1], fill_value=0, dtype="int64"
)
share_inputs["max_dec_len"] = paddle.full([max_num_seqs, 1], max_model_len, dtype="int64")
share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
share_inputs["actual_draft_token_num"] = paddle.full(
shape=[max_num_seqs], fill_value=max_draft_token_num, dtype="int32"
)
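# One scored token per prefill request (seq_lens_encoder != 0); decode requests score seq_lens_this_time tokens.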
share_inputs["batch_token_num"] = paddle.where(
share_inputs["seq_lens_encoder"] != 0,
paddle.ones_like(share_inputs["seq_lens_encoder"]),
share_inputs["seq_lens_this_time"],
).squeeze(1)
share_inputs["next_token_num"] = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32")
share_inputs["cu_batch_token_offset"] = paddle.concat(
[paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"])]
).astype("int32")
share_inputs["cu_next_token_offset"] = paddle.full(shape=[max_num_seqs + 1], fill_value=0, dtype="int32")
share_inputs["substep"] = 0
share_inputs["draft_logits"] = paddle.full(
[max_num_seqs * (max_draft_token_num + 1), vocab_size], -1, dtype="float32"
)
return share_inputs

def test_speculative_sampler():
batch_size = 32
vocab_size = 1024
min_seq_len = 1
max_seq_len = 1024
max_model_len = 1024
max_draft_token_num = 1
fd_config = _create_fd_config(max_model_len)
sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len)
logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size)
share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)
sampler = SpeculativeSampler(fd_config)
sampler(logits, sampling_metadata, max_model_len, share_inputs)

def test_speculative_sampler_logprobs():
batch_size = 32
vocab_size = 1024
min_seq_len = 1
max_seq_len = 1024
max_model_len = 1024
max_draft_token_num = 1
fd_config = _create_fd_config(max_model_len)
share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)
sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len, max_num_logprobs=0)
sampling_metadata.share_inputs = share_inputs
logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size)
logprobs_mode_list = ["raw_logprobs", "raw_logits"]
for logprobs_mode in logprobs_mode_list:
fd_config.model_config.logprobs_mode = logprobs_mode
sampler = SpeculativeSampler(fd_config)
sampler(logits, sampling_metadata, max_model_len, share_inputs)

def test_mtp_sampler():
batch_size = 32
vocab_size = 1024
min_seq_len = 1
max_seq_len = 1024
max_model_len = 1024
max_draft_token_num = 1
fd_config = _create_fd_config(max_model_len)
sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len)
logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size)
share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)
sampler = MTPSampler(fd_config)
sampler(logits, sampling_metadata, max_model_len, share_inputs)

def test_mtp_sampler_logprobs():
batch_size = 32
vocab_size = 1024
min_seq_len = 1
max_seq_len = 1024
max_model_len = 1024
max_draft_token_num = 1
fd_config = _create_fd_config(max_model_len)
share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)
sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len, max_num_logprobs=0)
sampling_metadata.share_inputs = share_inputs
logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size)
logprobs_mode_list = ["raw_logprobs", "raw_logits"]
for logprobs_mode in logprobs_mode_list:
fd_config.model_config.logprobs_mode = logprobs_mode
sampler = MTPSampler(fd_config)
sampler(logits, sampling_metadata, max_model_len, share_inputs)

if __name__ == "__main__":
test_speculative_sampler()
test_speculative_sampler_logprobs()
test_mtp_sampler()
test_mtp_sampler_logprobs()
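
A shape contract ties these four tests together: with max_draft_token_num = 1, every request scores max_draft_token_num + 1 logits rows per step, which is why logits is built with batch_size * 2 rows and seq_lens_this_time is filled with 2. A toy illustration (values are not part of the diff):

import paddle

batch_size, max_draft_token_num, vocab_size = 4, 1, 16
rows_per_request = max_draft_token_num + 1  # one row per scored token per request
logits = paddle.rand([batch_size * rows_per_request, vocab_size], dtype="float32")
seq_lens_this_time = paddle.full([batch_size, 1], rows_per_request, dtype="int32")
assert logits.shape[0] == int(seq_lens_this_time.sum())  # one logits row per scored token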

View File

@@ -0,0 +1,144 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from fastdeploy.model_executor.layers.sample.ops.speculate_logprob_utils import (
speculate_get_target_logits,
)

class TestSpeculateGetTargetLogits(unittest.TestCase):
def setUp(self):
self.vocab_size = 8192

def test_all_decode(self):
token_num = 6
logits = paddle.full(shape=[token_num, self.vocab_size], fill_value=-1, dtype="float32")
for i in range(token_num):
logits[i][:] = i
seq_lens_encoder = paddle.to_tensor([[0], [0], [0]], dtype="int32")
seq_lens_this_time = paddle.to_tensor([[2], [2], [2]], dtype="int32")
accept_num = paddle.to_tensor([1, 2, 1], dtype="int32")
batch_token_num = paddle.where(
seq_lens_encoder != 0,
paddle.ones_like(seq_lens_encoder),
seq_lens_this_time,
).squeeze(1)
ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
"int32"
)
cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(accept_num)]).astype("int32")
target_logits = paddle.empty([accept_num.sum(), logits.shape[1]], dtype=logits.dtype)
speculate_get_target_logits(
target_logits,
logits,
cu_batch_token_offset,
ori_cu_batch_token_offset,
seq_lens_this_time,
seq_lens_encoder,
accept_num,
)
gold_logits = paddle.full(shape=[4, self.vocab_size], fill_value=-1, dtype="float32")
gold_logits[0][:] = 0
gold_logits[1][:] = 2
gold_logits[2][:] = 3
gold_logits[3][:] = 4
assert paddle.allclose(target_logits, gold_logits)

def test_partial_decode(self):
token_num = 5
logits = paddle.full(shape=[token_num, self.vocab_size], fill_value=-1, dtype="float32")
for i in range(token_num):
logits[i][:] = i
seq_lens_encoder = paddle.to_tensor([[10], [0], [0]], dtype="int32")
seq_lens_this_time = paddle.to_tensor([[10], [2], [2]], dtype="int32")
accept_num = paddle.to_tensor([1, 2, 1], dtype="int32")
batch_token_num = paddle.where(
seq_lens_encoder != 0,
paddle.ones_like(seq_lens_encoder),
seq_lens_this_time,
).squeeze(1)
ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
"int32"
)
cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(accept_num)]).astype("int32")
target_logits = paddle.empty([accept_num.sum(), logits.shape[1]], dtype=logits.dtype)
speculate_get_target_logits(
target_logits,
logits,
cu_batch_token_offset,
ori_cu_batch_token_offset,
seq_lens_this_time,
seq_lens_encoder,
accept_num,
)
gold_logits = paddle.full(shape=[4, self.vocab_size], fill_value=-1, dtype="float32")
gold_logits[0][:] = 0
gold_logits[1][:] = 1
gold_logits[2][:] = 2
gold_logits[3][:] = 3
assert paddle.allclose(target_logits, gold_logits)

def test_all_prefill(self):
token_num = 3
logits = paddle.full(shape=[token_num, self.vocab_size], fill_value=-1, dtype="float32")
for i in range(token_num):
logits[i][:] = i
seq_lens_encoder = paddle.to_tensor([[10], [10], [10]], dtype="int32")
seq_lens_this_time = paddle.to_tensor([[10], [10], [10]], dtype="int32")
accept_num = paddle.to_tensor([1, 1, 1], dtype="int32")
batch_token_num = paddle.where(
seq_lens_encoder != 0,
paddle.ones_like(seq_lens_encoder),
seq_lens_this_time,
).squeeze(1)
ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
"int32"
)
cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(accept_num)]).astype("int32")
target_logits = paddle.empty([accept_num.sum(), logits.shape[1]], dtype=logits.dtype)
speculate_get_target_logits(
target_logits,
logits,
cu_batch_token_offset,
ori_cu_batch_token_offset,
seq_lens_this_time,
seq_lens_encoder,
accept_num,
)
gold_logits = paddle.full(shape=[3, self.vocab_size], fill_value=-1, dtype="float32")
gold_logits[0][:] = 0
gold_logits[1][:] = 1
gold_logits[2][:] = 2
assert paddle.allclose(target_logits, gold_logits)

if __name__ == "__main__":
unittest.main()
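
The three cases above pin down the gather speculate_get_target_logits performs: for each request, the first accept_num[i] rows of its raw-logits span are packed into a contiguous output. A pure-Python reference of that pattern, inferred from the test vectors (an illustration of the expected behavior, not the kernel's implementation; the seq-length arguments are omitted because the offset arrays already encode prefill vs. decode in these cases):

import paddle

def reference_get_target_logits(target_logits, logits, cu_batch_token_offset,
                                ori_cu_batch_token_offset, accept_num):
    for i in range(accept_num.shape[0]):
        n = int(accept_num[i])
        src = int(ori_cu_batch_token_offset[i])  # request i's span in the raw logits
        dst = int(cu_batch_token_offset[i])      # request i's span in the packed output
        target_logits[dst : dst + n] = logits[src : src + n]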

View File

@@ -0,0 +1,100 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
from fastdeploy.model_executor.layers.sample.ops.speculate_logprob_utils import (
speculate_insert_first_token,
)

class TestSpeculateInsertFirstToken(unittest.TestCase):
def test_all_decode(self):
token_num = 5
accept_tokens = paddle.to_tensor([[1001, 1002], [1003, 1004], [1005, 1006]], dtype="int64")
next_tokens = paddle.to_tensor([[2001], [2002], [2003], [2004], [2005]], dtype="int64")
cu_next_token_offset = paddle.to_tensor([0, 2, 3, 5], dtype="int32")
cu_batch_token_offset = paddle.to_tensor([0, 2, 3, 5], dtype="int32")
seq_lens_this_time = paddle.to_tensor([[2], [1], [2]], dtype="int32")
seq_lens_encoder = paddle.to_tensor([[0], [0], [0]], dtype="int32")
token_id = paddle.empty(token_num, dtype="int64")
speculate_insert_first_token(
token_id,
accept_tokens,
next_tokens,
cu_next_token_offset,
cu_batch_token_offset,
seq_lens_this_time,
seq_lens_encoder,
)
gold_token_id = paddle.to_tensor([2001, 2002, 2003, 2004, 2005], dtype="int64")
assert paddle.equal_all(token_id, gold_token_id)

def test_partial_decode(self):
token_num = 6
accept_tokens = paddle.to_tensor([[1001, 1002], [1003, 1004], [1005, 1006]], dtype="int64")
next_tokens = paddle.to_tensor([[2001], [2002], [2003], [2004], [2005]], dtype="int64")
cu_next_token_offset = paddle.to_tensor([0, 2, 3, 5], dtype="int32")
cu_batch_token_offset = paddle.to_tensor([0, 2, 4, 6], dtype="int32")
seq_lens_this_time = paddle.to_tensor([[2], [10], [2]], dtype="int32")
seq_lens_encoder = paddle.to_tensor([[0], [10], [0]], dtype="int32")
token_id = paddle.empty(token_num, dtype="int64")
speculate_insert_first_token(
token_id,
accept_tokens,
next_tokens,
cu_next_token_offset,
cu_batch_token_offset,
seq_lens_this_time,
seq_lens_encoder,
)
gold_token_id = paddle.to_tensor([2001, 2002, 1003, 2003, 2004, 2005], dtype="int64")
assert paddle.equal_all(token_id, gold_token_id)

def test_all_prefill(self):
token_num = 6
accept_tokens = paddle.to_tensor([[1001, 1002], [1003, 1004], [1005, 1006]], dtype="int64")
next_tokens = paddle.to_tensor([[2001], [2002], [2003]], dtype="int64")
cu_next_token_offset = paddle.to_tensor([0, 1, 2, 3], dtype="int32")
cu_batch_token_offset = paddle.to_tensor([0, 2, 4, 6], dtype="int32")
seq_lens_this_time = paddle.to_tensor([[10], [10], [10]], dtype="int32")
seq_lens_encoder = paddle.to_tensor([[10], [10], [10]], dtype="int32")
token_id = paddle.empty(token_num, dtype="int64")
speculate_insert_first_token(
token_id,
accept_tokens,
next_tokens,
cu_next_token_offset,
cu_batch_token_offset,
seq_lens_this_time,
seq_lens_encoder,
)
gold_token_id = paddle.to_tensor([1001, 2001, 1003, 2002, 1005, 2003], dtype="int64")
assert paddle.equal_all(token_id, gold_token_id)

if __name__ == "__main__":
unittest.main()
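
Likewise, these vectors characterize speculate_insert_first_token: a decode request copies its next_tokens span straight into the output, while a prefill request (seq_lens_encoder != 0) first inserts its first accepted token. A pure-Python reference inferred from the test vectors (illustrative only; seq_lens_this_time is dropped since prefill is identified by seq_lens_encoder alone in these cases):

import paddle

def reference_insert_first_token(token_id, accept_tokens, next_tokens,
                                 cu_next_token_offset, cu_batch_token_offset,
                                 seq_lens_encoder):
    flat_next = next_tokens.flatten()
    for i in range(seq_lens_encoder.shape[0]):
        dst = int(cu_batch_token_offset[i])
        src0 = int(cu_next_token_offset[i])
        src1 = int(cu_next_token_offset[i + 1])
        if int(seq_lens_encoder[i]) != 0:
            token_id[dst] = accept_tokens[i][0]  # prefill: first slot gets the first accepted token
            dst += 1
        token_id[dst : dst + (src1 - src0)] = flat_next[src0:src1]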