format flash_mask_attn

lizhenyun01
2025-11-18 13:33:37 +08:00
parent cd2c4df64a
commit d11235333e
2 changed files with 6 additions and 5 deletions


@@ -188,7 +188,7 @@ struct Softmax {
   using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
   TensorT row_max, row_sum;
-  CUTLASS_DEVICE Softmax() {};
+  CUTLASS_DEVICE Softmax(){};
   template <bool Is_first, bool Check_inf = false, typename Tensor0>
   __forceinline__ __device__ TensorT max(Tensor0 &acc_s,


@@ -56,7 +56,7 @@ class TestFlashMaskAttention(unittest.TestCase):
                 out[bsz, hi] = (np.matmul(qk, v_cur[bsz, hi // gqa_group_size]) * exp_sum_inv).astype(q_input.dtype)
         return out
 
-    def paddle_flash_attn_mask(self, q_input, k_input, v_input, mask):
+    def paddle_flash_attn_mask(self, q_input, k_input, v_input, attn_out, mask):
         bsz = q_input.shape[0]
         cu_seq_q = paddle.arange(bsz + 1) * q_input.shape[1]
         cu_seq_k = paddle.arange(bsz + 1) * k_input.shape[1]
@@ -71,13 +71,14 @@ class TestFlashMaskAttention(unittest.TestCase):
         v_input_pad[0 : v_input.shape[0]] = v_input
         mask = paddle.to_tensor(mask).astype("int32")
-        out = flash_attention_mask(
+        flash_attention_mask(
             q_input,
             k_input,
             v_input_pad,
             cu_seq_q,
             cu_seq_k,
             seq_len_encoder,
+            attn_out,
             mask,
             int(q_input.shape[1]),
             int(k_input.shape[1]),
@@ -86,7 +87,6 @@ class TestFlashMaskAttention(unittest.TestCase):
             int(q_input.shape[0]),
             int(k_input.shape[0]),
         )
-        return out
 
     def test_flash_attention_mask(self):
         q_input = np.random.normal(0, 0.5, size=(self.bsz, self.q_seq_len, self.num_head, self.head_dim))
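
The hunks above change the contract of `flash_attention_mask`: the op now writes into a caller-provided `attn_out` buffer instead of returning a fresh tensor. A minimal sketch of that destination-passing pattern, where `attn_like_op` is a hypothetical naive stand-in (not the real kernel) and the shapes are purely illustrative:

import paddle
import paddle.nn.functional as F

def attn_like_op(q, k, v, attn_out):
    # Naive attention used only to illustrate the contract: compute the
    # result, then copy it into the caller-provided buffer.
    scores = paddle.matmul(q, k, transpose_y=True) / (q.shape[-1] ** 0.5)
    probs = F.softmax(scores, axis=-1)
    paddle.assign(paddle.matmul(probs, v), output=attn_out)  # write in place

# Illustrative [bsz, num_head, seq_len, head_dim] layout.
q = paddle.randn([2, 8, 16, 64], dtype="float32")
k = paddle.randn([2, 8, 16, 64], dtype="float32")
v = paddle.randn([2, 8, 16, 64], dtype="float32")

attn_out = paddle.zeros([2, 8, 16, 64], dtype="float32")  # pre-allocated by the caller
attn_like_op(q, k, v, attn_out)  # no return value; attn_out now holds the result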
@@ -105,7 +105,8 @@ class TestFlashMaskAttention(unittest.TestCase):
         mask[text_len : text_len + image_len] = text_len + image_len + self.k_seq_len
         naive_attn_out = self.naive_attn(q_input, k_input, v_input, mask)
-        paddle_attn_out = self.paddle_flash_attn_mask(q_input, k_input, v_input, mask)
+        paddle_attn_out = paddle.zeros(naive_attn_out.shape, dtype="bfloat16")
+        self.paddle_flash_attn_mask(q_input, k_input, v_input, paddle_attn_out, mask)
         max_diff = float((paddle_attn_out.reshape([-1]) - paddle.to_tensor(naive_attn_out).reshape([-1])).max())
         self.assertLessEqual(max_diff, 0.05)
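
Design note: pre-allocating the bfloat16 `paddle_attn_out` buffer in the test mirrors how inference code typically reuses output buffers across decode steps, keeping allocation out of the op itself. The 0.05 max-difference tolerance against the float NumPy reference presumably absorbs bfloat16 rounding error.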