FastDeploy/tests/operators/test_speculative_schedule_cache.py

import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.gpu import speculate_schedule_cache


def cpu_reference(
    draft_tokens,
    block_tables,
    stop_flags,
    seq_lens_this_time,
    seq_lens_decoder,
    step_seq_lens_decoder,
    step_draft_tokens,
    step_seq_lens_this_time,
    accept_num,
    accept_tokens,
    is_block_step,
    not_need_stop,
    stop_nums,
    block_size,
    max_draft_tokens,
):
    """Pure-NumPy mirror of the CUDA kernel's logic (single block of 512 threads).

    Shapes are the same as the inputs to the custom op. This mutates the provided
    NumPy arrays in place, exactly like the kernel does.
    """
    real_bsz = seq_lens_this_time.shape[0]
    max_bsz = stop_flags.shape[0]
    draft_tokens_len = draft_tokens.shape[1]
    block_num_per_seq = block_tables.shape[1]
    max_next_step_tokens = 2 * max_draft_tokens + 2
    # Block-local reduction input per thread (threadIdx.x -> bid)
    stop_flag_now_int = np.zeros(512, dtype=np.int64)  # THREADBLOCK_SIZE = 512
    for bid in range(512):
        if bid < real_bsz:
            if not stop_flags[bid]:
                max_possible_block_idx = (seq_lens_decoder[bid] + max_next_step_tokens) // block_size
                if max_possible_block_idx < block_num_per_seq and block_tables[bid, max_possible_block_idx] == -1:
                    is_block_step[bid] = True
                    step_seq_lens_this_time[bid] = seq_lens_this_time[bid]
                    seq_lens_this_time[bid] = 0
                    stop_flags[bid] = True
                    step_seq_lens_decoder[bid] = seq_lens_decoder[bid]
                    seq_lens_decoder[bid] = 0
                    accept_num[bid] = 0
                    accept_tokens[bid, :] = -1
                    step_draft_tokens[bid, :draft_tokens_len] = draft_tokens[bid, :draft_tokens_len]
                    stop_flag_now_int[bid] = 1
                else:
                    stop_flag_now_int[bid] = 0
            else:
                stop_flag_now_int[bid] = 1
        elif bid < max_bsz:
            # Threads in [real_bsz, max_bsz) contribute 1 to the reduction
            stop_flag_now_int[bid] = 1
        else:
            stop_flag_now_int[bid] = 0
    stop_sum = int(stop_flag_now_int.sum())
    not_need_stop[0] = stop_sum < int(stop_nums[0])
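

# Worked example of the schedule decision, using the values from setUp below
# (block_size = 4, max_draft_tokens = 2, so max_next_step_tokens = 2 * 2 + 2 = 6):
# a request at seq_lens_decoder = 1 probes block index (1 + 6) // 4 = 1, and an
# unallocated entry (-1) there parks the request and snapshots its state into
# the step_* buffers. The block-wide sum then counts stopped-or-parked requests
# plus one per padding thread in [real_bsz, max_bsz); generation continues only
# while that sum stays strictly below stop_nums[0].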


class TestSpeculateScheduleCache(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        if not paddle.is_compiled_with_cuda():
            raise unittest.SkipTest("Paddle is not compiled with CUDA; skipping GPU op test.")
        paddle.device.set_device("gpu")

    def setUp(self):
        # --- Construct a deterministic case that exercises all branches ---
        # real_bsz < max_bsz to test the padding logic in the CUB reduction.
        self.real_bsz = 3
        self.max_bsz = 5  # only stop_flags has length max_bsz
        self.draft_tokens_len = 6
        self.accept_tokens_len = 5
        self.block_size = 4
        self.block_num_per_seq = 3
        self.max_draft_tokens = 2  # -> max_next_step_tokens = 6
        # Inputs that trigger for bid 0; bid 1 is already stopped; bid 2 does not trigger.
        # (seq_lens_decoder + 6) // 4 -> indices: [1, 1, 4]. Index 4 is out of range -> no trigger on bid 2.
        self.draft_tokens = paddle.to_tensor(
            np.array(
                [
                    [1, 1, 1, 1, 1, 1],
                    [2, 2, 2, 2, 2, 2],
                    [3, 3, 3, 3, 3, 3],
                ],
                dtype=np.int64,
            )
        )
        self.block_tables = paddle.to_tensor(np.full((self.real_bsz, self.block_num_per_seq), -1, dtype=np.int32))
        # stop_flags has length max_bsz; the other tensors have length real_bsz.
        self.stop_flags = paddle.to_tensor(np.array([False, True, False, False, False], dtype=np.bool_))
        self.seq_lens_this_time = paddle.to_tensor(np.array([5, 6, 7], dtype=np.int32))
        self.seq_lens_decoder = paddle.to_tensor(np.array([1, 1, 10], dtype=np.int32))
        # Will be filled by the kernel for the triggering bids only.
        self.step_seq_lens_decoder = paddle.zeros((self.real_bsz,), dtype="int32")
        self.step_draft_tokens = paddle.zeros((self.real_bsz, self.draft_tokens_len), dtype="int64")
        self.step_seq_lens_this_time = paddle.zeros((self.real_bsz,), dtype="int32")
        # Intentionally non-zero so we can verify in-place zeroing only where triggered.
        self.accept_num = paddle.to_tensor(np.array([9, 8, 7], dtype=np.int32))
        self.accept_tokens = paddle.to_tensor(
            np.arange(self.real_bsz * self.accept_tokens_len, dtype=np.int64).reshape(
                self.real_bsz, self.accept_tokens_len
            )
        )
        self.is_block_step = paddle.zeros((self.real_bsz,), dtype=paddle.bool)
        # not_need_stop lives on CPU in the caller; the kernel copies it to the device internally.
        self.not_need_stop = paddle.zeros((1,), dtype=paddle.bool).cpu()
        # Choose the threshold so that: bid 0 triggers, bid 1 is already stopped,
        # padding (5 - 3) = 2 -> stop_sum = 1 + 1 + 2 = 4.
        # Set stop_nums to 5 so not_need_stop = (4 < 5) = True.
        self.stop_nums = paddle.to_tensor([5], dtype=paddle.int64)
        # Keep NumPy copies for the CPU reference.
        self.np_draft_tokens = self.draft_tokens.numpy().copy()
        self.np_block_tables = self.block_tables.numpy().copy()
        self.np_stop_flags = self.stop_flags.numpy().copy()
        self.np_seq_lens_this_time = self.seq_lens_this_time.numpy().copy()
        self.np_seq_lens_decoder = self.seq_lens_decoder.numpy().copy()
        self.np_step_seq_lens_decoder = self.step_seq_lens_decoder.numpy().copy()
        self.np_step_draft_tokens = self.step_draft_tokens.numpy().copy()
        self.np_step_seq_lens_this_time = self.step_seq_lens_this_time.numpy().copy()
        self.np_accept_num = self.accept_num.numpy().copy()
        self.np_accept_tokens = self.accept_tokens.numpy().copy()
        self.np_is_block_step = self.is_block_step.numpy().copy()
        self.np_not_need_stop = self.not_need_stop.numpy().copy()
        self.np_stop_nums = self.stop_nums.numpy().copy()
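
    # Per-bid expectations for these inputs: bid 0 triggers (candidate index 1 is
    # unallocated), bid 1 is skipped because it is already stopped, and bid 2's
    # candidate index 4 falls outside block_num_per_seq = 3, so it never triggers.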

    def test_correctness_against_cpu_reference(self):
        # Run the GPU kernel (in-place)
        speculate_schedule_cache(
            self.draft_tokens,
            self.block_tables,
            self.stop_flags,
            self.seq_lens_this_time,
            self.seq_lens_decoder,
            self.step_seq_lens_decoder,
            self.step_draft_tokens,
            self.step_seq_lens_this_time,
            self.accept_num,
            self.accept_tokens,
            self.is_block_step,
            self.not_need_stop,
            self.stop_nums,
            self.block_size,
            self.max_draft_tokens,
        )
        # Compute the CPU reference (in-place on the NumPy copies)
        cpu_reference(
            self.np_draft_tokens,
            self.np_block_tables,
            self.np_stop_flags,
            self.np_seq_lens_this_time,
            self.np_seq_lens_decoder,
            self.np_step_seq_lens_decoder,
            self.np_step_draft_tokens,
            self.np_step_seq_lens_this_time,
            self.np_accept_num,
            self.np_accept_tokens,
            self.np_is_block_step,
            self.np_not_need_stop,
            self.np_stop_nums,
            self.block_size,
            self.max_draft_tokens,
        )
        # Compare all mutated tensors
        np.testing.assert_array_equal(self.step_draft_tokens.numpy(), self.np_step_draft_tokens)
        np.testing.assert_array_equal(self.accept_tokens.numpy(), self.np_accept_tokens)
        np.testing.assert_array_equal(self.stop_flags.numpy(), self.np_stop_flags)
        np.testing.assert_array_equal(self.is_block_step.numpy(), self.np_is_block_step)
        np.testing.assert_array_equal(self.seq_lens_this_time.numpy(), self.np_seq_lens_this_time)
        np.testing.assert_array_equal(self.seq_lens_decoder.numpy(), self.np_seq_lens_decoder)
        np.testing.assert_array_equal(self.step_seq_lens_decoder.numpy(), self.np_step_seq_lens_decoder)
        np.testing.assert_array_equal(self.step_seq_lens_this_time.numpy(), self.np_step_seq_lens_this_time)
        np.testing.assert_array_equal(self.accept_num.numpy(), self.np_accept_num)
        self.assertEqual(bool(self.not_need_stop.numpy()[0]), bool(self.np_not_need_stop[0]))
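
    # Every tensor the kernel may mutate is compared in the parity test above, so
    # a divergence in any branch of the CUDA code shows up as a mismatch against
    # the NumPy mirror.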

    def test_no_trigger_path(self):
        # Make block_tables at the candidate index != -1 so nothing triggers.
        # The candidate index for bid 0/1 is 1; set it to 7.
        bt = self.block_tables.numpy()
        bt[:, 1] = 7
        self.block_tables = paddle.to_tensor(bt)
        # Reset outputs to distinctive values
        self.step_seq_lens_decoder[:] = 0
        self.step_draft_tokens[:] = 0
        self.step_seq_lens_this_time[:] = 0
        self.accept_num[:] = -123
        self.accept_tokens[:] = -777
        self.is_block_step[:] = False
        self.not_need_stop[:] = False
        # For not_need_stop: stopped_in_real = (bid 1 True) = 1, padding = 2 -> stop_sum = 3.
        # With stop_nums = 5 -> True.
        speculate_schedule_cache(
            self.draft_tokens,
            self.block_tables,
            self.stop_flags,
            self.seq_lens_this_time,
            self.seq_lens_decoder,
            self.step_seq_lens_decoder,
            self.step_draft_tokens,
            self.step_seq_lens_this_time,
            self.accept_num,
            self.accept_tokens,
            self.is_block_step,
            self.not_need_stop,
            self.stop_nums,
            self.block_size,
            self.max_draft_tokens,
        )
        # Nothing should have changed except not_need_stop
        np.testing.assert_array_equal(self.step_draft_tokens.numpy(), np.zeros_like(self.step_draft_tokens.numpy()))
        np.testing.assert_array_equal(self.is_block_step.numpy(), np.zeros_like(self.is_block_step.numpy()))
        np.testing.assert_array_equal(self.accept_tokens.numpy(), np.full_like(self.accept_tokens.numpy(), -777))
        np.testing.assert_array_equal(self.accept_num.numpy(), np.full_like(self.accept_num.numpy(), -123))
        self.assertTrue(bool(self.not_need_stop.numpy()[0]))
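
    def test_stop_threshold_boundary(self):
        # A hypothetical boundary check (a sketch, not part of the original suite):
        # with the default setUp, bid 0 triggers and bid 1 is already stopped, so
        # together with the (5 - 3) = 2 padding threads the reduction sum is 4.
        # Lowering stop_nums to 4 makes the comparison 4 < 4 false, so
        # not_need_stop must come back False.
        self.stop_nums = paddle.to_tensor([4], dtype=paddle.int64)
        speculate_schedule_cache(
            self.draft_tokens,
            self.block_tables,
            self.stop_flags,
            self.seq_lens_this_time,
            self.seq_lens_decoder,
            self.step_seq_lens_decoder,
            self.step_draft_tokens,
            self.step_seq_lens_this_time,
            self.accept_num,
            self.accept_tokens,
            self.is_block_step,
            self.not_need_stop,
            self.stop_nums,
            self.block_size,
            self.max_draft_tokens,
        )
        self.assertFalse(bool(self.not_need_stop.numpy()[0]))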


if __name__ == "__main__":
    unittest.main()