Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-11-03 02:53:26 +08:00
support mtp in v1_scheduler mode (#3695)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
tests/operators/test_speculative_schedule_cache.py (new file, +239 lines)
@@ -0,0 +1,239 @@
import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.gpu import speculate_schedule_cache


def cpu_reference(
    draft_tokens,
    block_tables,
    stop_flags,
    seq_lens_this_time,
    seq_lens_decoder,
    step_seq_lens_decoder,
    step_draft_tokens,
    step_seq_lens_this_time,
    accept_num,
    accept_tokens,
    is_block_step,
    not_need_stop,
    stop_nums,
    block_size,
    max_draft_tokens,
):
    """Pure-NumPy mirror of the CUDA kernel's logic (single block of 512 threads).

    Shapes are the same as inputs to the custom op. This mutates the provided
    NumPy arrays in-place, exactly like the kernel does.
    """
    real_bsz = seq_lens_this_time.shape[0]
    max_bsz = stop_flags.shape[0]
    draft_tokens_len = draft_tokens.shape[1]
    block_num_per_seq = block_tables.shape[1]

    max_next_step_tokens = 2 * max_draft_tokens + 2

    # Block-local reduction input per thread (threadIdx.x -> bid)
    stop_flag_now_int = np.zeros(512, dtype=np.int64)  # THREADBLOCK_SIZE = 512

    for bid in range(512):
        if bid < real_bsz:
            if not stop_flags[bid]:
                max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) // block_size
                if max_possible_block_idx < block_num_per_seq and block_tables[bid, max_possible_block_idx] == -1:
                    is_block_step[bid] = True
                    step_seq_lens_this_time[bid] = seq_lens_this_time[bid]
                    seq_lens_this_time[bid] = 0
                    stop_flags[bid] = True
                    step_seq_lens_decoder[bid] = seq_lens_decoder[bid]
                    seq_lens_decoder[bid] = 0
                    accept_num[bid] = 0
                    accept_tokens[bid, :] = -1
                    step_draft_tokens[bid, :draft_tokens_len] = draft_tokens[bid, :draft_tokens_len]
                    stop_flag_now_int[bid] = 1
                else:
                    stop_flag_now_int[bid] = 0
            else:
                stop_flag_now_int[bid] = 1
        elif bid < max_bsz:
            # Threads in [real_bsz, max_bsz) contribute 1 to the reduction
            stop_flag_now_int[bid] = 1
        else:
            stop_flag_now_int[bid] = 0

    stop_sum = int(stop_flag_now_int.sum())
    not_need_stop[0] = stop_sum < int(stop_nums[0])

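# Worked example of the trigger condition above (illustrative note; the values
# come from the test below): with max_draft_tokens=2 and block_size=4, the
# kernel looks ahead max_next_step_tokens = 2 * 2 + 2 = 6 tokens, so the
# probed block-table column is
#   candidate_idx = (seq_len_decoder + 6) // 4
# For seq_lens_decoder = [1, 1, 10] that gives columns [1, 1, 4]; a request is
# parked as a block step only when its column is in range and still holds -1.
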
class TestSpeculateScheduleCache(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        if not paddle.is_compiled_with_cuda():
            raise unittest.SkipTest("Paddle is not compiled with CUDA; skipping GPU op test.")
        paddle.device.set_device("gpu")

    def setUp(self):
        # --- Construct a deterministic case that exercises all branches ---
        # real_bsz < max_bsz to test the padding logic in the CUB reduction
        self.real_bsz = 3
        self.max_bsz = 5  # only stop_flags has length max_bsz

        self.draft_tokens_len = 6
        self.accept_tokens_len = 5
        self.block_size = 4
        self.block_num_per_seq = 3
        self.max_draft_tokens = 2  # -> max_next_step_tokens = 6

        # Inputs that trigger for bid 0 and not for bid 2; bid 1 is already stopped.
        # (seq_lens_decoder + 6) // 4 -> indices [1, 1, 4]; index 4 is out of range -> no trigger on bid 2
        self.draft_tokens = paddle.to_tensor(
            np.array(
                [
                    [1, 1, 1, 1, 1, 1],
                    [2, 2, 2, 2, 2, 2],
                    [3, 3, 3, 3, 3, 3],
                ],
                dtype=np.int64,
            )
        )
        self.block_tables = paddle.to_tensor(np.full((self.real_bsz, self.block_num_per_seq), -1, dtype=np.int32))
        # stop_flags has length max_bsz; the others have length real_bsz
        self.stop_flags = paddle.to_tensor(np.array([False, True, False, False, False], dtype=np.bool_))
        self.seq_lens_this_time = paddle.to_tensor(np.array([5, 6, 7], dtype=np.int32))
        self.seq_lens_decoder = paddle.to_tensor(np.array([1, 1, 10], dtype=np.int32))

        # Filled by the kernel for the triggering bids only
        self.step_seq_lens_decoder = paddle.zeros((self.real_bsz,), dtype="int32")
        self.step_draft_tokens = paddle.zeros((self.real_bsz, self.draft_tokens_len), dtype="int64")
        self.step_seq_lens_this_time = paddle.zeros((self.real_bsz,), dtype="int32")

        # Intentionally non-zero so we can verify in-place zeroing only where triggered
        self.accept_num = paddle.to_tensor(np.array([9, 8, 7], dtype=np.int32))
        self.accept_tokens = paddle.to_tensor(
            np.arange(self.real_bsz * self.accept_tokens_len, dtype=np.int64).reshape(
                self.real_bsz, self.accept_tokens_len
            )
        )
        self.is_block_step = paddle.zeros((self.real_bsz,), dtype=paddle.bool)

        # not_need_stop lives on CPU in the caller; the kernel copies it to device internally
        self.not_need_stop = paddle.zeros((1,), dtype=paddle.bool).cpu()

        # Choose the threshold knowing that: bid 0 triggers, bid 1 is already stopped,
        # and padding adds (5 - 3) = 2 -> stop_sum = 1 + 1 + 2 = 4.
        # Set stop_nums to 5 so not_need_stop = (4 < 5) = True
        self.stop_nums = paddle.to_tensor([5], dtype=paddle.int64)

        # Keep NumPy copies for the CPU reference
        self.np_draft_tokens = self.draft_tokens.numpy().copy()
        self.np_block_tables = self.block_tables.numpy().copy()
        self.np_stop_flags = self.stop_flags.numpy().copy()
        self.np_seq_lens_this_time = self.seq_lens_this_time.numpy().copy()
        self.np_seq_lens_decoder = self.seq_lens_decoder.numpy().copy()
        self.np_step_seq_lens_decoder = self.step_seq_lens_decoder.numpy().copy()
        self.np_step_draft_tokens = self.step_draft_tokens.numpy().copy()
        self.np_step_seq_lens_this_time = self.step_seq_lens_this_time.numpy().copy()
        self.np_accept_num = self.accept_num.numpy().copy()
        self.np_accept_tokens = self.accept_tokens.numpy().copy()
        self.np_is_block_step = self.is_block_step.numpy().copy()
        self.np_not_need_stop = self.not_need_stop.numpy().copy()
        self.np_stop_nums = self.stop_nums.numpy().copy()

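    # Illustrative note on the reduction semantics being tested: the CUDA kernel
    # folds stop_flag_now_int over all 512 threads with a CUB block reduction.
    # Threads in [real_bsz, max_bsz) count as stopped so that padded batch slots
    # can never hold the scheduler open, while threads >= max_bsz contribute 0;
    # not_need_stop is True exactly when stop_sum < stop_nums[0].
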
    def test_correctness_against_cpu_reference(self):
        # Run the GPU kernel (in-place)
        speculate_schedule_cache(
            self.draft_tokens,
            self.block_tables,
            self.stop_flags,
            self.seq_lens_this_time,
            self.seq_lens_decoder,
            self.step_seq_lens_decoder,
            self.step_draft_tokens,
            self.step_seq_lens_this_time,
            self.accept_num,
            self.accept_tokens,
            self.is_block_step,
            self.not_need_stop,
            self.stop_nums,
            self.block_size,
            self.max_draft_tokens,
        )

        # Compute the CPU reference (in-place on the NumPy copies)
        cpu_reference(
            self.np_draft_tokens,
            self.np_block_tables,
            self.np_stop_flags,
            self.np_seq_lens_this_time,
            self.np_seq_lens_decoder,
            self.np_step_seq_lens_decoder,
            self.np_step_draft_tokens,
            self.np_step_seq_lens_this_time,
            self.np_accept_num,
            self.np_accept_tokens,
            self.np_is_block_step,
            self.np_not_need_stop,
            self.np_stop_nums,
            self.block_size,
            self.max_draft_tokens,
        )

        # Compare all mutated tensors
        np.testing.assert_array_equal(self.step_draft_tokens.numpy(), self.np_step_draft_tokens)
        np.testing.assert_array_equal(self.accept_tokens.numpy(), self.np_accept_tokens)
        np.testing.assert_array_equal(self.stop_flags.numpy(), self.np_stop_flags)
        np.testing.assert_array_equal(self.is_block_step.numpy(), self.np_is_block_step)
        np.testing.assert_array_equal(self.seq_lens_this_time.numpy(), self.np_seq_lens_this_time)
        np.testing.assert_array_equal(self.seq_lens_decoder.numpy(), self.np_seq_lens_decoder)
        np.testing.assert_array_equal(self.step_seq_lens_decoder.numpy(), self.np_step_seq_lens_decoder)
        np.testing.assert_array_equal(self.step_seq_lens_this_time.numpy(), self.np_step_seq_lens_this_time)
        np.testing.assert_array_equal(self.accept_num.numpy(), self.np_accept_num)
        self.assertEqual(bool(self.not_need_stop.numpy()[0]), bool(self.np_not_need_stop[0]))

    def test_no_trigger_path(self):
        # Set block_tables at the candidate index to != -1 so nothing triggers.
        # The candidate index for bid 0/1 is 1; set that column to 7.
        bt = self.block_tables.numpy()
        bt[:, 1] = 7
        self.block_tables = paddle.to_tensor(bt)

        # Reset outputs to distinctive values
        self.step_seq_lens_decoder[:] = 0
        self.step_draft_tokens[:] = 0
        self.step_seq_lens_this_time[:] = 0
        self.accept_num[:] = -123
        self.accept_tokens[:] = -777
        self.is_block_step[:] = False
        self.not_need_stop[:] = False

        # For not_need_stop: stopped in the real batch = 1 (bid 1 is True), padding = 2 -> stop_sum = 3.
        # With stop_nums = 5 -> True
        speculate_schedule_cache(
            self.draft_tokens,
            self.block_tables,
            self.stop_flags,
            self.seq_lens_this_time,
            self.seq_lens_decoder,
            self.step_seq_lens_decoder,
            self.step_draft_tokens,
            self.step_seq_lens_this_time,
            self.accept_num,
            self.accept_tokens,
            self.is_block_step,
            self.not_need_stop,
            self.stop_nums,
            self.block_size,
            self.max_draft_tokens,
        )

        # Nothing should have changed except not_need_stop
        np.testing.assert_array_equal(self.step_draft_tokens.numpy(), np.zeros_like(self.step_draft_tokens.numpy()))
        np.testing.assert_array_equal(self.is_block_step.numpy(), np.zeros_like(self.is_block_step.numpy()))
        np.testing.assert_array_equal(self.accept_tokens.numpy(), np.full_like(self.accept_tokens.numpy(), -777))
        np.testing.assert_array_equal(self.accept_num.numpy(), np.full_like(self.accept_num.numpy(), -123))
        self.assertTrue(bool(self.not_need_stop.numpy()[0]))


if __name__ == "__main__":
    unittest.main()

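As a minimal sketch of what the test verifies, the snippet below drives the NumPy reference alone, with the same values setUp constructs, and checks the expected outcome by hand. It assumes cpu_reference from the listing above is in scope (importing the test module itself would also pull in the fastdeploy GPU ops); it is an illustration, not part of the committed file.

import numpy as np

draft_tokens = np.ones((3, 6), dtype=np.int64)
block_tables = np.full((3, 3), -1, dtype=np.int32)
stop_flags = np.array([False, True, False, False, False])  # length max_bsz = 5
seq_lens_this_time = np.array([5, 6, 7], dtype=np.int32)
seq_lens_decoder = np.array([1, 1, 10], dtype=np.int32)
step_seq_lens_decoder = np.zeros(3, dtype=np.int32)
step_draft_tokens = np.zeros((3, 6), dtype=np.int64)
step_seq_lens_this_time = np.zeros(3, dtype=np.int32)
accept_num = np.array([9, 8, 7], dtype=np.int32)
accept_tokens = np.arange(15, dtype=np.int64).reshape(3, 5)
is_block_step = np.zeros(3, dtype=bool)
not_need_stop = np.zeros(1, dtype=bool)
stop_nums = np.array([5], dtype=np.int64)

cpu_reference(
    draft_tokens, block_tables, stop_flags, seq_lens_this_time,
    seq_lens_decoder, step_seq_lens_decoder, step_draft_tokens,
    step_seq_lens_this_time, accept_num, accept_tokens, is_block_step,
    not_need_stop, stop_nums, block_size=4, max_draft_tokens=2,
)

# bid 0 parks: its candidate block (1 + 6) // 4 = 1 is unallocated.
assert is_block_step[0] and accept_num[0] == 0 and (accept_tokens[0] == -1).all()
# bid 2 does not park: its candidate index (10 + 6) // 4 = 4 is out of range.
assert not is_block_step[2]
# stop_sum = 1 (bid 0) + 1 (bid 1 stopped) + 2 (padding) = 4 < 5 -> keep scheduling.
assert bool(not_need_stop[0])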