Files
FastDeploy/tests/operators/test_update_attn_mask.py
freeliuzc 11398790d3
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
[Speculative Decoding][MTP]Support attn mask offset (#4641)
* [MTP]Merge support attn (#4591)

* support mask_offset in speculate decoding

* fix dummy run output

* add unit test

* fix unit test import

* support attn_mask_offset in mtp mode

* add update_attn_mask op

* fix unit test && fix code-style
2025-11-03 10:08:01 +08:00

278 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import unittest
import numpy as np
import paddle
# Ensure the compiled op is importable from this path
from fastdeploy.model_executor.ops.gpu import update_attn_mask_offsets
def py_update_attn_mask_offsets_op(
    ids_remove_padding_len,
    seq_lens_this_time,
    seq_lens_encoder,
    seq_lens_decoder,
    cu_seqlens_q,
    attn_mask_offsets_full,
    attn_mask_offsets_decoder,
    is_block_step,
    decode_states,
    mask_rollback,
):
    """Pure-NumPy reference for the CUDA ``update_attn_mask_offsets`` kernel.

    Args:
        ids_remove_padding_len: total de-padded token count. Unused here;
            the flattened length is recomputed from ``seq_lens_this_time``.
        seq_lens_this_time: int32 vector (bsz,) of tokens produced this step.
        seq_lens_encoder: int32 vector (bsz,) of encoder (prefill) lengths.
        seq_lens_decoder: int32 vector (bsz,) of decoder context lengths.
        cu_seqlens_q: int32 exclusive prefix sums of ``seq_lens_this_time``
            (length bsz); entry ``b`` is the first flattened token index of
            batch entry ``b``.
        attn_mask_offsets_full: int32 array of shape (bsz, max_model_len).
        attn_mask_offsets_decoder: int32 vector (bsz,) of running decoder
            offsets. A private copy is advanced; the caller's array is not
            mutated.
        is_block_step: bool vector (bsz,); ``True`` entries are skipped.
        decode_states: int32 array (bsz, decode_states_len). A copy is
            updated and returned.
        mask_rollback: int32 vector (bsz,) or (bsz, 1); subtracted from the
            decoder offset before writing.

    Returns:
        tuple: ``(attn_mask_offsets, decode_states)`` where
        ``attn_mask_offsets`` is a flat int32 vector of length
        ``2 * sum(seq_lens_this_time)`` (odd slots carry the offsets, even
        slots stay zero) and ``decode_states`` is the updated
        (bsz, decode_states_len) int32 copy.
    """

    def _vec_i32(values):
        # Copy to a fresh, flat int32 vector (never a view of the input).
        return np.array(values, dtype=np.int32).reshape(-1)

    seq_lens_this_time = _vec_i32(seq_lens_this_time)
    seq_lens_encoder = _vec_i32(seq_lens_encoder)
    seq_lens_decoder = _vec_i32(seq_lens_decoder)
    cu_seqlens_q = _vec_i32(cu_seqlens_q)
    is_block_step = np.array(is_block_step, dtype=bool).reshape(-1)
    attn_mask_offsets_full = np.array(attn_mask_offsets_full, dtype=np.int32)
    attn_mask_offsets_decoder = _vec_i32(attn_mask_offsets_decoder)
    decode_states = np.array(decode_states, dtype=np.int32).copy()
    mask_rollback = _vec_i32(mask_rollback)

    bsz = int(seq_lens_this_time.shape[0])
    flat_len = int(seq_lens_this_time.sum())
    decode_states_len = int(decode_states.shape[1])
    # The CUDA side allocates paddle::full({batch_seq_lens * 2}, 0).
    attn_mask_offsets = np.zeros((flat_len * 2,), dtype=np.int32)

    for batch_idx in range(bsz):
        if is_block_step[batch_idx]:
            # Entry parked by block-step scheduling: leave it untouched.
            continue

        num_tokens = int(seq_lens_this_time[batch_idx])
        enc_len = int(seq_lens_encoder[batch_idx])
        dec_len = int(seq_lens_decoder[batch_idx])
        q_base = int(cu_seqlens_q[batch_idx])
        full_row = attn_mask_offsets_full[batch_idx]
        state_row = decode_states[batch_idx]  # view into the returned copy

        if enc_len > 0:
            # Prefill path. The kernel compares only the FIRST decode state
            # ('*decode_states_now == 2' in C++) to detect the vision
            # generate phase; the check is loop-invariant, so hoist it.
            vision_phase = state_row.size > 0 and state_row[0] == 2 and dec_len > 0
            for tok in range(num_tokens):
                slot = (q_base + tok) * 2 + 1
                if vision_phase:
                    attn_mask_offsets[slot] = dec_len + num_tokens
                else:
                    attn_mask_offsets[slot] = int(full_row[tok]) + 1
        elif dec_len > 0:
            # Decoder path: rewind the running offset by mask_rollback,
            # then write consecutive offsets for this step's tokens.
            rollback = int(mask_rollback[batch_idx]) if batch_idx < mask_rollback.shape[0] else 0
            attn_mask_offsets_decoder[batch_idx] = int(attn_mask_offsets_decoder[batch_idx]) - rollback
            base = int(attn_mask_offsets_decoder[batch_idx])
            for tok in range(num_tokens):
                attn_mask_offsets[(q_base + tok) * 2 + 1] = base + 1 + tok
            # Advance the running decoder offset past this step.
            attn_mask_offsets_decoder[batch_idx] = int(attn_mask_offsets_decoder[batch_idx]) + num_tokens
            if num_tokens > 1:
                # Speculative decoding: mark the first num_tokens state
                # slots active (0) and the remainder inactive (-1).
                state_row[:num_tokens] = 0
                state_row[num_tokens:] = -1
        # else: encoder and decoder lengths both zero -> stopped entry,
        # nothing is written for this batch.

    return attn_mask_offsets, decode_states
class UpdateAttnMaskOffsetsTestCase(unittest.TestCase):
    """Cross-check the compiled ``update_attn_mask_offsets`` GPU op against
    the NumPy reference ``py_update_attn_mask_offsets_op`` on identical
    inputs, covering each kernel branch: stop, prefill, vision-generate,
    decoder, and a mixed batch."""

    def setUp(self):
        """Select a device before every test: GPU when available (to match
        the compiled op's placement), otherwise fall back to CPU."""
        # If GPU available, use it. But we don't hard require CUDA here; op itself must be callable.
        # Ensure Paddle uses GPU if available to match operator placement
        try:
            paddle.set_device("gpu")
        except Exception:
            paddle.set_device("cpu")

    def _call_and_compare(
        self,
        seq_lens_this_time,
        seq_lens_encoder,
        seq_lens_decoder,
        is_block_step,
        max_model_len=8,
        decode_states_len=4,
        vision_generate=False,
    ):
        """Build one synthetic batch, run the compiled op and the Python
        reference on identical inputs, and assert bit-exact equality.

        Args:
            seq_lens_this_time: tokens produced this step, per batch entry.
            seq_lens_encoder: encoder (prefill) lengths, per batch entry.
            seq_lens_decoder: decoder context lengths, per batch entry.
            is_block_step: per-entry flags; True entries are skipped by the op.
            max_model_len: width of the attn_mask_offsets_full table.
            decode_states_len: number of decode-state slots per entry.
            vision_generate: seed decode_states[:, 0] = 2 so the vision
                generate branch fires (together with seq_lens_decoder > 0).
        """
        # build numpy inputs
        seq_lens_this_time = np.array(seq_lens_this_time, dtype=np.int32).reshape(-1)
        seq_lens_encoder = np.array(seq_lens_encoder, dtype=np.int32).reshape(-1)
        seq_lens_decoder = np.array(seq_lens_decoder, dtype=np.int32).reshape(-1)
        bsz = seq_lens_this_time.shape[0]
        total_seq = int(np.sum(seq_lens_this_time))
        # Exclusive prefix sums of seq_lens_this_time: cu_seqlens_q[b] is the
        # first flattened token index of batch entry b (cu_seqlens_q[0] == 0).
        cu_seqlens_q = np.zeros((bsz,), dtype=np.int32)
        if bsz > 1:
            cu_seqlens_q[1:] = np.cumsum(seq_lens_this_time[:-1])
        # attn_mask_offsets_full: shape (bsz, max_model_len).
        # Deterministic, all-distinct values so copy errors in the op show up.
        attn_mask_offsets_full = np.arange(bsz * max_model_len, dtype=np.int32).reshape(bsz, max_model_len)
        # attn_mask_offsets_decoder initial (use seq_lens_decoder as seed for deterministic test)
        attn_mask_offsets_decoder = np.array(seq_lens_decoder, dtype=np.int32).copy()
        # decode_states initial: -1 everywhere (inactive slots)
        decode_states = np.full((bsz, decode_states_len), -1, dtype=np.int32)
        if vision_generate:
            decode_states[:, 0] = 2  # make first element 2 to trigger vision phase
        mask_rollback = np.zeros((bsz,), dtype=np.int32)
        # ids_remove_padding: length = total_seq (only length used by op)
        ids_remove_padding = paddle.randint(low=0, high=10, shape=[total_seq], dtype="int32")
        decode_states_tensor = paddle.to_tensor(decode_states, dtype="int32")
        # prepare paddle tensors and call the compiled op
        # NOTE(review): argument order must match the C++ PD_BUILD_STATIC_OP
        # registration exactly — verify against the kernel if this fails.
        out = update_attn_mask_offsets(
            ids_remove_padding,
            paddle.to_tensor(seq_lens_this_time, dtype="int32"),
            paddle.to_tensor(seq_lens_encoder, dtype="int32"),
            paddle.to_tensor(seq_lens_decoder, dtype="int32"),
            paddle.to_tensor(cu_seqlens_q, dtype="int32"),
            paddle.to_tensor(attn_mask_offsets_full, dtype="int32"),
            paddle.to_tensor(attn_mask_offsets_decoder, dtype="int32"),
            paddle.to_tensor(np.array(is_block_step, dtype=bool).reshape(-1), dtype="bool"),
            decode_states_tensor,
            paddle.to_tensor(mask_rollback, dtype="int32"),
        )
        # op returns [attn_mask_offsets, decode_states_out] per your PD_BUILD_STATIC_OP outputs
        if isinstance(out, (list, tuple)):
            op_attn_mask_offsets = out[0].numpy().astype(np.int32).reshape(-1)
            op_decode_states = out[1].numpy().astype(np.int32)
        else:
            # Some bindings might return single tensor and inplace decode_states update
            # Try to handle that case: assume attn_mask_offsets returned and decode_states was mutated inplace.
            op_attn_mask_offsets = out.numpy().astype(np.int32).reshape(-1)
            # fetch decode_states by re-creating input decode_states tensor? best effort:
            # (we passed decode_states as a paddle tensor; in operator we passed a copy, but PD set inplace mapping
            # so many builds will actually give decode_states_out as second output; this block is fallback.)
            op_decode_states = decode_states_tensor.numpy()
        # compute python reference outputs
        # (.copy() on the mutable inputs keeps the originals intact for debug printing)
        ref_attn_mask_offsets, ref_decode_states = py_update_attn_mask_offsets_op(
            ids_remove_padding_len=total_seq,
            seq_lens_this_time=seq_lens_this_time,
            seq_lens_encoder=seq_lens_encoder,
            seq_lens_decoder=seq_lens_decoder,
            cu_seqlens_q=cu_seqlens_q,
            attn_mask_offsets_full=attn_mask_offsets_full,
            attn_mask_offsets_decoder=attn_mask_offsets_decoder.copy(),
            is_block_step=np.array(is_block_step, dtype=bool).reshape(-1),
            decode_states=decode_states.copy(),
            mask_rollback=mask_rollback,
        )
        # optionally print debug if env var set
        if os.environ.get("ATTN_MASK_TEST_DEBUG", "0") == "1":
            print("=== DEBUG ===")
            print("seq_lens_this_time:", seq_lens_this_time)
            print("seq_lens_encoder:", seq_lens_encoder)
            print("seq_lens_decoder:", seq_lens_decoder)
            print("cu_seqlens_q:", cu_seqlens_q)
            print("ref_attn_mask_offsets:", ref_attn_mask_offsets)
            print("op_attn_mask_offsets:", op_attn_mask_offsets)
            print("ref_decode_states:", ref_decode_states)
            print("op_decode_states:", op_decode_states)
            print("=============")
        # shape checks
        self.assertEqual(
            op_attn_mask_offsets.shape,
            ref_attn_mask_offsets.shape,
            f"attn_mask_offsets shape mismatch: op {op_attn_mask_offsets.shape}, ref {ref_attn_mask_offsets.shape}",
        )
        # element-wise equality
        np.testing.assert_array_equal(op_attn_mask_offsets, ref_attn_mask_offsets)
        np.testing.assert_array_equal(op_decode_states, ref_decode_states)

    # --- Test cases below (cover branches) ---
    def test_stop_case(self):
        """Stopped entry (encoder == decoder == 0): output stays all zeros."""
        # stop: both encoder and decoder are zero -> nothing written (all zeros)
        self._call_and_compare(
            seq_lens_this_time=[1],
            seq_lens_encoder=[0],
            seq_lens_decoder=[0],
            is_block_step=[False],
            max_model_len=4,
            decode_states_len=2,
        )

    def test_prefill_case(self):
        """Prefill branch: offsets copied from attn_mask_offsets_full + 1."""
        # prefill: encoder > 0, should copy attn_mask_offsets_full[i] + 1 into positions ((q+i)*2+1)
        self._call_and_compare(
            seq_lens_this_time=[3],
            seq_lens_encoder=[3],
            seq_lens_decoder=[0],
            is_block_step=[False],
            max_model_len=8,
            decode_states_len=4,
        )

    def test_vision_generate_prefill(self):
        """Vision-generate branch: decode_states[0] == 2 with decoder > 0."""
        # vision generate: decode_states[0] == 2 and seq_len_decoder > 0 triggers alternate write
        self._call_and_compare(
            seq_lens_this_time=[2],
            seq_lens_encoder=[2],
            seq_lens_decoder=[5],  # >0 to activate vision branch
            is_block_step=[False],
            max_model_len=8,
            decode_states_len=4,
            vision_generate=True,
        )

    def test_decoder_case(self):
        """Decoder branch: consecutive offsets from the running decoder base."""
        # decoder path: should write attn_mask_offsets_decoder - rollback + 1 .. +seq_len_this_time-1
        self._call_and_compare(
            seq_lens_this_time=[2],
            seq_lens_encoder=[0],
            seq_lens_decoder=[7],
            is_block_step=[False],
            max_model_len=8,
            decode_states_len=6,
        )

    def test_mixed_batch_case(self):
        """One batch mixing decoder, prefill, and decoder entries."""
        # mixed batch with different statuses
        self._call_and_compare(
            seq_lens_this_time=[2, 4, 1],
            seq_lens_encoder=[0, 4, 0],
            seq_lens_decoder=[5, 0, 1],
            is_block_step=[False, False, False],
            max_model_len=12,
            decode_states_len=2,
        )
# Allow running this file directly: discover and run all test cases above.
if __name__ == "__main__":
    unittest.main()