Add ut for speculative sampler (#4650)
@@ -480,7 +480,7 @@ class SpeculativeSampler(nn.Layer):
         share_inputs = sampling_metadata.share_inputs
         last_logits = logits
         real_bsz = share_inputs["seq_lens_this_time"].shape[0]
-        batch_token_num = share_inputs["batch_token_num"][:real_bsz]
+        batch_token_num = share_inputs["accept_num"][:real_bsz]

         temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
         top_p_normalized_logprobs = sampling_metadata.top_p_normalized_logprobs
@@ -637,7 +637,7 @@ class SpeculativeSampler(nn.Layer):
         batch_token_num = paddle.where(
             share_inputs["seq_lens_encoder"][:real_bsz] != 0,
             paddle.ones_like(share_inputs["seq_lens_encoder"][:real_bsz]),
-            share_inputs["accept_num"][:real_bsz].unsqueeze(1),
+            share_inputs["seq_lens_this_time"],
         ).squeeze(1)
         share_inputs["batch_token_num"] = batch_token_num
         ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
@@ -647,11 +647,11 @@ class SpeculativeSampler(nn.Layer):
             [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"][:real_bsz])]
         ).astype("int32")
         share_inputs["cu_batch_token_offset"] = cu_batch_token_offset
-        target_logtis = paddle.empty(
+        target_logits = paddle.empty(
             [share_inputs["accept_num"][:real_bsz].sum(), logits.shape[1]], dtype=logits.dtype
         )
         speculate_get_target_logits(
-            target_logtis,
+            target_logits,
             logits,
             cu_batch_token_offset,
             ori_cu_batch_token_offset,
@@ -660,25 +660,22 @@ class SpeculativeSampler(nn.Layer):
             share_inputs["accept_num"],
         )
         if self.logprobs_mode == "raw_logprobs":
-            raw_logprobs = self.compute_logprobs(target_logtis, sampling_metadata)
+            raw_logprobs = self.compute_logprobs(target_logits, sampling_metadata)
         elif self.logprobs_mode == "raw_logits":
-            raw_logprobs = target_logtis.clone()
+            raw_logprobs = target_logits.clone()

         logprobs_tensors = None
         token_ids = share_inputs["accept_tokens"]
         if num_logprobs is not None:
             token_ids = paddle.concat(
-                [
-                    share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]]
-                    for i in range(share_inputs["accept_num"][:real_bsz].shape[0])
-                ]
+                [share_inputs["accept_tokens"][i, : share_inputs["accept_num"][i]] for i in range(real_bsz)]
             )
             logprobs_tensors = self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=token_ids)

         sampler_output = SamplerOutput(
             sampled_token_ids=token_ids,
             logprobs_tensors=logprobs_tensors,
-            token_num_per_batch=batch_token_num,
+            token_num_per_batch=share_inputs["accept_num"],
             cu_batch_token_offset=share_inputs["cu_batch_token_offset"],
         )

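Note: the bookkeeping in these hunks reduces to two cumulative-sum vectors over per-batch token counts. A minimal standalone sketch with hypothetical values (one prefill request followed by two decode requests), mirroring the expressions above:

import paddle

# Hypothetical batch: request 0 is prefilling, requests 1-2 are decoding.
seq_lens_encoder = paddle.to_tensor([[10], [0], [0]], dtype="int32")
seq_lens_this_time = paddle.to_tensor([[10], [2], [2]], dtype="int32")
accept_num = paddle.to_tensor([1, 2, 1], dtype="int32")

# A prefill request contributes one logit row; a decode request contributes
# one row per draft token evaluated this step.
batch_token_num = paddle.where(
    seq_lens_encoder != 0,
    paddle.ones_like(seq_lens_encoder),
    seq_lens_this_time,
).squeeze(1)                                    # -> [1, 2, 2]

# Offsets into the raw logits (all rows) vs. the accepted target logits.
ori_cu_batch_token_offset = paddle.concat(
    [paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]
).astype("int32")                               # -> [0, 1, 3, 5]
cu_batch_token_offset = paddle.concat(
    [paddle.to_tensor([0]), paddle.cumsum(accept_num)]
).astype("int32")                               # -> [0, 1, 3, 4]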

tests/layers/test_speculative_sampler.py (new file, 226 lines)
@@ -0,0 +1,226 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from unittest.mock import Mock

import paddle

from fastdeploy.config import (
    CacheConfig,
    FDConfig,
    GraphOptimizationConfig,
    ParallelConfig,
    SchedulerConfig,
    SpeculativeConfig,
)
from fastdeploy.model_executor.layers.sample.meta_data import SamplingMetadata
from fastdeploy.model_executor.layers.sample.sampler import (
    MTPSampler,
    SpeculativeSampler,
)


def _create_fake_logits(batch_size: int, vocab_size: int) -> paddle.Tensor:
    fake_logits = paddle.rand(shape=[batch_size, vocab_size], dtype="float32")
    return fake_logits


def _create_penalty_tensor(batch_size: int, penalty_value: float) -> paddle.Tensor:
    return paddle.full(shape=[batch_size, 1], fill_value=penalty_value, dtype="float32")


def _create_tokens_tensor(
    batch_size: int,
    max_seq_len: int,
) -> paddle.Tensor:
    pre_token_ids = paddle.full(shape=[batch_size, max_seq_len], fill_value=-1, dtype="int64")
    return pre_token_ids


def _create_default_sampling_metadata(
    batch_size: int,
    min_seq_len: int,
    max_seq_len: int,
    max_num_logprobs: int = None,
) -> SamplingMetadata:
    fake_sampling_metadata = SamplingMetadata(
        temperature=paddle.full(shape=[batch_size, 1], fill_value=0.9, dtype="float32"),
        top_p=paddle.full(shape=[batch_size, 1], fill_value=0.7, dtype="float32"),
        prompt_ids=paddle.full(shape=[batch_size, max_seq_len], fill_value=0, dtype="int64"),
        prompt_lens=paddle.full(shape=[batch_size, 1], fill_value=5, dtype="int64"),
        step_idx=paddle.full(shape=[batch_size, 1], fill_value=0, dtype="int64"),
        pre_token_ids=_create_tokens_tensor(batch_size, max_seq_len),
        frequency_penalties=_create_penalty_tensor(batch_size, 0.0),
        presence_penalties=_create_penalty_tensor(batch_size, 0.0),
        repetition_penalties=_create_penalty_tensor(batch_size, 1.0),
        min_dec_lens=paddle.full(shape=[batch_size, 1], fill_value=min_seq_len, dtype="int64"),
        bad_words_token_ids=paddle.full(shape=[batch_size], fill_value=-1, dtype="int64"),
        eos_token_ids=paddle.full(shape=[batch_size], fill_value=-2, dtype="int64"),
        min_p=paddle.randn([batch_size]),
        seed=paddle.to_tensor([[2025]]),
    )
    if max_num_logprobs is not None:
        fake_sampling_metadata.max_num_logprobs = max_num_logprobs
    return fake_sampling_metadata


def _create_fd_config(max_model_len):
    model_config: Mock = Mock()
    model_config.max_model_len = max_model_len
    speculative_config = SpeculativeConfig({})
    graph_opt_config = GraphOptimizationConfig({})
    scheduler_config = SchedulerConfig({})
    parallel_config = ParallelConfig({})
    cache_config = CacheConfig({})
    cache_config.cache_transfer_protocol = "rdma,ipc"
    cache_config.pd_comm_port = "2334"
    fd_config = FDConfig(
        model_config=model_config,
        speculative_config=speculative_config,
        graph_opt_config=graph_opt_config,
        scheduler_config=scheduler_config,
        parallel_config=parallel_config,
        cache_config=cache_config,
    )

    return fd_config


def _create_share_inputs(max_num_seqs, max_draft_token_num, max_model_len, vocab_size):
    share_inputs = {}
    share_inputs["seq_lens_this_time"] = paddle.full([max_num_seqs, 1], 2, dtype="int32")
    share_inputs["output_cum_offsets"] = paddle.concat(
        [(max_model_len - share_inputs["seq_lens_this_time"][i]) * i for i in range(max_num_seqs)]
    )
    share_inputs["output_padding_offset"] = paddle.repeat_interleave(share_inputs["output_cum_offsets"], 2)
    share_inputs["accept_tokens"] = paddle.full(
        shape=[max_num_seqs, max_draft_token_num + 1], fill_value=0, dtype="int64"
    )
    share_inputs["accept_num"] = paddle.full(shape=[max_num_seqs], fill_value=1, dtype="int32")
    share_inputs["step_idx"] = paddle.full([max_num_seqs, 1], 1, dtype="int64")
    share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], False, dtype="bool")
    share_inputs["seq_lens_encoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
    share_inputs["seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 2, dtype="int32")
    share_inputs["draft_tokens"] = paddle.full(
        shape=[max_num_seqs, max_draft_token_num + 1], fill_value=0, dtype="int64"
    )
    share_inputs["max_dec_len"] = paddle.full([max_num_seqs, 1], max_model_len, dtype="int64")
    share_inputs["is_block_step"] = paddle.full([max_num_seqs], False, dtype="bool")
    share_inputs["actual_draft_token_num"] = paddle.full(
        shape=[max_num_seqs], fill_value=max_draft_token_num, dtype="int32"
    )

    share_inputs["batch_token_num"] = paddle.where(
        share_inputs["seq_lens_encoder"] != 0,
        paddle.ones_like(share_inputs["seq_lens_encoder"]),
        share_inputs["seq_lens_this_time"],
    ).squeeze(1)
    share_inputs["next_token_num"] = paddle.full(shape=[max_num_seqs], fill_value=0, dtype="int32")
    share_inputs["cu_batch_token_offset"] = paddle.concat(
        [paddle.to_tensor([0]), paddle.cumsum(share_inputs["accept_num"])]
    ).astype("int32")
    share_inputs["cu_next_token_offset"] = paddle.full(shape=[max_num_seqs + 1], fill_value=0, dtype="int32")
    share_inputs["substep"] = 0
    share_inputs["draft_logits"] = paddle.full(
        [max_num_seqs * (max_draft_token_num + 1), vocab_size], -1, dtype="float32"
    )

    return share_inputs
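As a quick sanity check of the fixture above: with every request in decode (`seq_lens_encoder` all zero), two tokens per step, and one accepted token per request, the derived offset fields are fully determined. A minimal illustration (hypothetical call, not part of the committed file):

# Hypothetical sanity check of _create_share_inputs for max_num_seqs = 4.
inputs = _create_share_inputs(max_num_seqs=4, max_draft_token_num=1, max_model_len=1024, vocab_size=1024)
# seq_lens_encoder is all zeros, so batch_token_num falls through to
# seq_lens_this_time:
assert inputs["batch_token_num"].tolist() == [2, 2, 2, 2]
# accept_num is all ones, so the cumulative offsets advance by one per batch:
assert inputs["cu_batch_token_offset"].tolist() == [0, 1, 2, 3, 4]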

def test_speculative_sampler():
    batch_size = 32
    vocab_size = 1024
    min_seq_len = 1
    max_seq_len = 1024
    max_model_len = 1024
    max_draft_token_num = 1

    fd_config = _create_fd_config(max_model_len)
    sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len)
    logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size)
    share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)

    sampler = SpeculativeSampler(fd_config)
    sampler(logits, sampling_metadata, max_model_len, share_inputs)


def test_speculative_sampler_logprobs():
    batch_size = 32
    vocab_size = 1024
    min_seq_len = 1
    max_seq_len = 1024
    max_model_len = 1024
    max_draft_token_num = 1

    fd_config = _create_fd_config(max_model_len)
    share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)
    sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len, max_num_logprobs=0)
    sampling_metadata.share_inputs = share_inputs
    logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size)

    logprobs_mode_list = ["raw_logprobs", "raw_logits"]
    for logprobs_mode in logprobs_mode_list:
        fd_config.model_config.logprobs_mode = logprobs_mode
        sampler = SpeculativeSampler(fd_config)
        sampler(logits, sampling_metadata, max_model_len, share_inputs)


def test_mtp_sampler():
    batch_size = 32
    vocab_size = 1024
    min_seq_len = 1
    max_seq_len = 1024
    max_model_len = 1024
    max_draft_token_num = 1

    fd_config = _create_fd_config(max_model_len)
    sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len)
    logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size)

    share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)

    sampler = MTPSampler(fd_config)
    sampler(logits, sampling_metadata, max_model_len, share_inputs)


def test_mtp_sampler_logprobs():
    batch_size = 32
    vocab_size = 1024
    min_seq_len = 1
    max_seq_len = 1024
    max_model_len = 1024
    max_draft_token_num = 1

    fd_config = _create_fd_config(max_model_len)
    share_inputs = _create_share_inputs(batch_size, max_draft_token_num, max_model_len, vocab_size)
    sampling_metadata = _create_default_sampling_metadata(batch_size, min_seq_len, max_seq_len, max_num_logprobs=0)
    sampling_metadata.share_inputs = share_inputs
    logits = _create_fake_logits(batch_size * (max_draft_token_num + 1), vocab_size)

    logprobs_mode_list = ["raw_logprobs", "raw_logits"]
    for logprobs_mode in logprobs_mode_list:
        fd_config.model_config.logprobs_mode = logprobs_mode
        sampler = MTPSampler(fd_config)
        sampler(logits, sampling_metadata, max_model_len, share_inputs)


if __name__ == "__main__":
    test_speculative_sampler()
    test_speculative_sampler_logprobs()
    test_mtp_sampler()
    test_mtp_sampler_logprobs()
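The tests above only assert that the forward pass runs. A hedged sketch of stronger assertions for the logprobs variant, assuming the sampler call returns the `SamplerOutput` constructed in the diff at the top of this page (illustrative additions, not part of the commit):

# Hypothetical extension of test_speculative_sampler_logprobs; assumes the
# sampler call returns the SamplerOutput built in the diff above.
out = sampler(logits, sampling_metadata, max_model_len, share_inputs)
# With accept_num all ones, the gathered token ids flatten to exactly one
# accepted token per sequence:
assert out.sampled_token_ids.shape[0] == int(share_inputs["accept_num"].sum())
# token_num_per_batch is now passed straight through from accept_num:
assert out.token_num_per_batch is share_inputs["accept_num"]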

tests/operators/test_speculate_get_target_logits.py (new file, 144 lines)
@@ -0,0 +1,144 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle

from fastdeploy.model_executor.layers.sample.ops.speculate_logprob_utils import (
    speculate_get_target_logits,
)


class TestSpeculateGetTargetLogits(unittest.TestCase):

    def setUp(self):
        self.vocab_size = 8192

    def test_all_decode(self):
        token_num = 6
        logits = paddle.full(shape=[token_num, self.vocab_size], fill_value=-1, dtype="float32")
        for i in range(token_num):
            logits[i][:] = i

        seq_lens_encoder = paddle.to_tensor([[0], [0], [0]], dtype="int32")
        seq_lens_this_time = paddle.to_tensor([[2], [2], [2]], dtype="int32")
        accept_num = paddle.to_tensor([1, 2, 1], dtype="int32")
        batch_token_num = paddle.where(
            seq_lens_encoder != 0,
            paddle.ones_like(seq_lens_encoder),
            seq_lens_this_time,
        ).squeeze(1)
        ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
            "int32"
        )
        cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(accept_num)]).astype("int32")
        target_logits = paddle.empty([accept_num.sum(), logits.shape[1]], dtype=logits.dtype)

        speculate_get_target_logits(
            target_logits,
            logits,
            cu_batch_token_offset,
            ori_cu_batch_token_offset,
            seq_lens_this_time,
            seq_lens_encoder,
            accept_num,
        )

        gold_logits = paddle.full(shape=[4, self.vocab_size], fill_value=-1, dtype="float32")
        gold_logits[0][:] = 0
        gold_logits[1][:] = 2
        gold_logits[2][:] = 3
        gold_logits[3][:] = 4

        assert paddle.allclose(target_logits, gold_logits)

    def test_partial_decode(self):
        token_num = 5
        logits = paddle.full(shape=[token_num, self.vocab_size], fill_value=-1, dtype="float32")
        for i in range(token_num):
            logits[i][:] = i

        seq_lens_encoder = paddle.to_tensor([[10], [0], [0]], dtype="int32")
        seq_lens_this_time = paddle.to_tensor([[10], [2], [2]], dtype="int32")
        accept_num = paddle.to_tensor([1, 2, 1], dtype="int32")
        batch_token_num = paddle.where(
            seq_lens_encoder != 0,
            paddle.ones_like(seq_lens_encoder),
            seq_lens_this_time,
        ).squeeze(1)
        ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
            "int32"
        )
        cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(accept_num)]).astype("int32")
        target_logits = paddle.empty([accept_num.sum(), logits.shape[1]], dtype=logits.dtype)

        speculate_get_target_logits(
            target_logits,
            logits,
            cu_batch_token_offset,
            ori_cu_batch_token_offset,
            seq_lens_this_time,
            seq_lens_encoder,
            accept_num,
        )

        gold_logits = paddle.full(shape=[4, self.vocab_size], fill_value=-1, dtype="float32")
        gold_logits[0][:] = 0
        gold_logits[1][:] = 1
        gold_logits[2][:] = 2
        gold_logits[3][:] = 3

        assert paddle.allclose(target_logits, gold_logits)

    def test_all_prefill(self):
        token_num = 3
        logits = paddle.full(shape=[token_num, self.vocab_size], fill_value=-1, dtype="float32")
        for i in range(token_num):
            logits[i][:] = i

        seq_lens_encoder = paddle.to_tensor([[10], [10], [10]], dtype="int32")
        seq_lens_this_time = paddle.to_tensor([[10], [10], [10]], dtype="int32")
        accept_num = paddle.to_tensor([1, 1, 1], dtype="int32")
        batch_token_num = paddle.where(
            seq_lens_encoder != 0,
            paddle.ones_like(seq_lens_encoder),
            seq_lens_this_time,
        ).squeeze(1)
        ori_cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(batch_token_num)]).astype(
            "int32"
        )
        cu_batch_token_offset = paddle.concat([paddle.to_tensor([0]), paddle.cumsum(accept_num)]).astype("int32")
        target_logits = paddle.empty([accept_num.sum(), logits.shape[1]], dtype=logits.dtype)

        speculate_get_target_logits(
            target_logits,
            logits,
            cu_batch_token_offset,
            ori_cu_batch_token_offset,
            seq_lens_this_time,
            seq_lens_encoder,
            accept_num,
        )

        gold_logits = paddle.full(shape=[3, self.vocab_size], fill_value=-1, dtype="float32")
        gold_logits[0][:] = 0
        gold_logits[1][:] = 1
        gold_logits[2][:] = 2

        assert paddle.allclose(target_logits, gold_logits)


if __name__ == "__main__":
    unittest.main()
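To make the gold tensors above checkable by hand, here is the index arithmetic for test_all_decode, spelled out as comments (derived from the test inputs; the per-row copy rule is inferred from the gold values):

# Index arithmetic behind test_all_decode's gold values:
#   batch_token_num = [2, 2, 2] -> ori_cu_batch_token_offset = [0, 2, 4, 6]
#   accept_num      = [1, 2, 1] -> cu_batch_token_offset     = [0, 1, 3, 4]
# Batch b's accept_num[b] rows are copied from the raw logits starting at
# ori_cu_batch_token_offset[b] into target_logits starting at
# cu_batch_token_offset[b]:
#   batch 0: raw row 0     -> target row 0      (fill value 0)
#   batch 1: raw rows 2, 3 -> target rows 1, 2  (fill values 2, 3)
#   batch 2: raw row 4     -> target row 3      (fill value 4)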

tests/operators/test_speculate_insert_first_token.py (new file, 100 lines)
@@ -0,0 +1,100 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle

from fastdeploy.model_executor.layers.sample.ops.speculate_logprob_utils import (
    speculate_insert_first_token,
)


class TestSpeculateInsertFirstToken(unittest.TestCase):

    def test_all_decode(self):
        token_num = 5
        accept_tokens = paddle.to_tensor([[1001, 1002], [1003, 1004], [1005, 1006]], dtype="int64")
        next_tokens = paddle.to_tensor([[2001], [2002], [2003], [2004], [2005]], dtype="int64")
        cu_next_token_offset = paddle.to_tensor([0, 2, 3, 5], dtype="int32")
        cu_batch_token_offset = paddle.to_tensor([0, 2, 3, 5], dtype="int32")
        seq_lens_this_time = paddle.to_tensor([[2], [1], [2]], dtype="int32")
        seq_lens_encoder = paddle.to_tensor([[0], [0], [0]], dtype="int32")

        token_id = paddle.empty(token_num, dtype="int64")
        speculate_insert_first_token(
            token_id,
            accept_tokens,
            next_tokens,
            cu_next_token_offset,
            cu_batch_token_offset,
            seq_lens_this_time,
            seq_lens_encoder,
        )

        gold_token_id = paddle.to_tensor([2001, 2002, 2003, 2004, 2005], dtype="int64")

        assert paddle.equal_all(token_id, gold_token_id)

    def test_partial_decode(self):
        token_num = 6
        accept_tokens = paddle.to_tensor([[1001, 1002], [1003, 1004], [1005, 1006]], dtype="int64")
        next_tokens = paddle.to_tensor([[2001], [2002], [2003], [2004], [2005]], dtype="int64")
        cu_next_token_offset = paddle.to_tensor([0, 2, 3, 5], dtype="int32")
        cu_batch_token_offset = paddle.to_tensor([0, 2, 4, 6], dtype="int32")
        seq_lens_this_time = paddle.to_tensor([[2], [10], [2]], dtype="int32")
        seq_lens_encoder = paddle.to_tensor([[0], [10], [0]], dtype="int32")

        token_id = paddle.empty(token_num, dtype="int64")
        speculate_insert_first_token(
            token_id,
            accept_tokens,
            next_tokens,
            cu_next_token_offset,
            cu_batch_token_offset,
            seq_lens_this_time,
            seq_lens_encoder,
        )

        gold_token_id = paddle.to_tensor([2001, 2002, 1003, 2003, 2004, 2005], dtype="int64")

        assert paddle.equal_all(token_id, gold_token_id)

    def test_all_prefill(self):
        token_num = 6
        accept_tokens = paddle.to_tensor([[1001, 1002], [1003, 1004], [1005, 1006]], dtype="int64")
        next_tokens = paddle.to_tensor([[2001], [2002], [2003]], dtype="int64")
        cu_next_token_offset = paddle.to_tensor([0, 1, 2, 3], dtype="int32")
        cu_batch_token_offset = paddle.to_tensor([0, 2, 4, 6], dtype="int32")
        seq_lens_this_time = paddle.to_tensor([[10], [10], [10]], dtype="int32")
        seq_lens_encoder = paddle.to_tensor([[10], [10], [10]], dtype="int32")

        token_id = paddle.empty(token_num, dtype="int64")
        speculate_insert_first_token(
            token_id,
            accept_tokens,
            next_tokens,
            cu_next_token_offset,
            cu_batch_token_offset,
            seq_lens_this_time,
            seq_lens_encoder,
        )

        gold_token_id = paddle.to_tensor([1001, 2001, 1003, 2002, 1005, 2003], dtype="int64")

        assert paddle.equal_all(token_id, gold_token_id)


if __name__ == "__main__":
    unittest.main()
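The mixed prefill/decode case is the trickiest to read; here is the slot-by-slot mapping behind test_partial_decode's gold tensor, derived from the offsets in the test (the copy rule is inferred from the gold values):

# Slot mapping behind test_partial_decode's gold_token_id:
#   cu_batch_token_offset = [0, 2, 4, 6] -> each batch owns two output slots
#   cu_next_token_offset  = [0, 2, 3, 5] -> next_tokens rows per batch
#   batch 0 (decode):  slots 0-1 <- next_tokens rows 0-1 -> 2001, 2002
#   batch 1 (prefill): slot 2    <- accept_tokens[1][0]  -> 1003
#                      slot 3    <- next_tokens row 2    -> 2003
#   batch 2 (decode):  slots 4-5 <- next_tokens rows 3-4 -> 2004, 2005
# gold_token_id = [2001, 2002, 1003, 2003, 2004, 2005]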