From d11235333e4d19199c83e998499ceda74252b511 Mon Sep 17 00:00:00 2001
From: lizhenyun01 <1500424927@qq.com>
Date: Tue, 18 Nov 2025 13:33:37 +0800
Subject: [PATCH] format flash_mask_attn

---
 custom_ops/gpu_ops/flash_mask_attn/softmax.hpp | 2 +-
 tests/operators/test_flash_mask_attn.py        | 9 +++++----
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/custom_ops/gpu_ops/flash_mask_attn/softmax.hpp b/custom_ops/gpu_ops/flash_mask_attn/softmax.hpp
index 386f67d8c..cd48d349b 100644
--- a/custom_ops/gpu_ops/flash_mask_attn/softmax.hpp
+++ b/custom_ops/gpu_ops/flash_mask_attn/softmax.hpp
@@ -188,7 +188,7 @@ struct Softmax {
   using TensorT = decltype(make_tensor(Shape>{}));
   TensorT row_max, row_sum;
 
-  CUTLASS_DEVICE Softmax() {};
+  CUTLASS_DEVICE Softmax(){};
 
   template
   __forceinline__ __device__ TensorT max(Tensor0 &acc_s,
diff --git a/tests/operators/test_flash_mask_attn.py b/tests/operators/test_flash_mask_attn.py
index d7849406d..2ada04527 100644
--- a/tests/operators/test_flash_mask_attn.py
+++ b/tests/operators/test_flash_mask_attn.py
@@ -56,7 +56,7 @@ class TestFlashMaskAttention(unittest.TestCase):
                 out[bsz, hi] = (np.matmul(qk, v_cur[bsz, hi // gqa_group_size]) * exp_sum_inv).astype(q_input.dtype)
         return out
 
-    def paddle_flash_attn_mask(self, q_input, k_input, v_input, mask):
+    def paddle_flash_attn_mask(self, q_input, k_input, v_input, attn_out, mask):
         bsz = q_input.shape[0]
         cu_seq_q = paddle.arange(bsz + 1) * q_input.shape[1]
         cu_seq_k = paddle.arange(bsz + 1) * k_input.shape[1]
@@ -71,13 +71,14 @@ class TestFlashMaskAttention(unittest.TestCase):
             v_input_pad[0 : v_input.shape[0]] = v_input
 
         mask = paddle.to_tensor(mask).astype("int32")
-        out = flash_attention_mask(
+        flash_attention_mask(
             q_input,
             k_input,
             v_input_pad,
             cu_seq_q,
             cu_seq_k,
             seq_len_encoder,
+            attn_out,
             mask,
             int(q_input.shape[1]),
             int(k_input.shape[1]),
@@ -86,7 +87,6 @@ class TestFlashMaskAttention(unittest.TestCase):
             int(q_input.shape[0]),
             int(k_input.shape[0]),
         )
-        return out
 
     def test_flash_attention_mask(self):
         q_input = np.random.normal(0, 0.5, size=(self.bsz, self.q_seq_len, self.num_head, self.head_dim))
@@ -105,7 +105,8 @@ class TestFlashMaskAttention(unittest.TestCase):
             mask[text_len : text_len + image_len] = text_len + image_len + self.k_seq_len
 
         naive_attn_out = self.naive_attn(q_input, k_input, v_input, mask)
-        paddle_attn_out = self.paddle_flash_attn_mask(q_input, k_input, v_input, mask)
+        paddle_attn_out = paddle.zeros(naive_attn_out.shape, dtype="bfloat16")
+        self.paddle_flash_attn_mask(q_input, k_input, v_input, paddle_attn_out, mask)
 
         max_diff = float((paddle_attn_out.reshape([-1]) - paddle.to_tensor(naive_attn_out).reshape([-1])).max())
         self.assertLessEqual(max_diff, 0.05)
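
Note: with this change flash_attention_mask no longer returns its result; the caller pre-allocates the output tensor and passes it in as attn_out, and the op writes into it in place. A minimal usage sketch mirroring the updated test, assuming the compiled custom op is already importable in the test module (its import is outside the hunks above) and that the inputs are laid out as in the test; expected_out_shape is a placeholder for the reference output's shape:

    import paddle

    # Caller allocates the output buffer; the op fills it in place.
    attn_out = paddle.zeros(expected_out_shape, dtype="bfloat16")
    flash_attention_mask(
        q_input, k_input, v_input_pad,
        cu_seq_q, cu_seq_k, seq_len_encoder,
        attn_out,  # new argument: output tensor, inserted before mask
        mask,
        int(q_input.shape[1]), int(k_input.shape[1]),
        ...,  # remaining integer arguments unchanged (not shown in the hunks)
        int(q_input.shape[0]), int(k_input.shape[0]),
    )
    # attn_out now holds the attention result; there is no return value to capture.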