Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)
format flash_mask_attn
@@ -188,7 +188,7 @@ struct Softmax {
     using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
     TensorT row_max, row_sum;
 
-    CUTLASS_DEVICE Softmax() {};
+    CUTLASS_DEVICE Softmax(){};
 
     template <bool Is_first, bool Check_inf = false, typename Tensor0>
     __forceinline__ __device__ TensorT max(Tensor0 &acc_s,
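Context for the hunk above: row_max and row_sum hold the per-row running statistics used by flash-attention-style online softmax. A rough NumPy sketch of that bookkeeping follows; it is illustrative only, not FastDeploy's kernel code.

import numpy as np

def online_softmax_step(row_max, row_sum, acc_o, scores, v_block):
    # One step of the online-softmax recurrence: fold a new tile of raw
    # attention scores (and its V tile) into the running max / sum / output.
    new_max = np.maximum(row_max, scores.max(axis=-1))
    correction = np.exp(row_max - new_max)           # rescale the old statistics
    p = np.exp(scores - new_max[:, None])
    new_sum = row_sum * correction + p.sum(axis=-1)
    new_acc = acc_o * correction[:, None] + p @ v_block
    return new_max, new_sum, new_acc

# Initialize row_max = -inf, row_sum = 0, acc_o = 0, loop over K/V tiles,
# then divide acc_o by row_sum[:, None] to recover softmax(Q K^T) V.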
@@ -56,7 +56,7 @@ class TestFlashMaskAttention(unittest.TestCase):
                 out[bsz, hi] = (np.matmul(qk, v_cur[bsz, hi // gqa_group_size]) * exp_sum_inv).astype(q_input.dtype)
         return out
 
-    def paddle_flash_attn_mask(self, q_input, k_input, v_input, mask):
+    def paddle_flash_attn_mask(self, q_input, k_input, v_input, attn_out, mask):
         bsz = q_input.shape[0]
         cu_seq_q = paddle.arange(bsz + 1) * q_input.shape[1]
         cu_seq_k = paddle.arange(bsz + 1) * k_input.shape[1]
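For context on the naive_attn reference whose tail the hunk above shows (note the hi // gqa_group_size head mapping), here is a rough, self-contained NumPy version of a masked GQA attention reference. The tensor layout and the mask semantics (mask[i] = number of key positions query i may attend to) are assumptions inferred from the test, not the project's exact implementation.

import numpy as np

def naive_masked_gqa_attention(q, k, v, mask, gqa_group_size):
    # q: [bsz, q_len, num_q_head, head_dim]; k, v: [bsz, k_len, num_kv_head, head_dim]
    # mask: [q_len], assumed to give how many key positions each query may see.
    bsz, q_len, num_q_head, head_dim = q.shape
    out = np.zeros_like(q)
    scale = 1.0 / np.sqrt(head_dim)
    for b in range(bsz):
        for h in range(num_q_head):
            kv_h = h // gqa_group_size                       # query head -> shared KV head
            scores = q[b, :, h] @ k[b, :, kv_h].T * scale    # [q_len, k_len]
            for i in range(q_len):
                scores[i, mask[i]:] = -np.inf                # hide keys beyond the visible span
            p = np.exp(scores - scores.max(axis=-1, keepdims=True))
            p /= p.sum(axis=-1, keepdims=True)
            out[b, :, h] = p @ v[b, :, kv_h]
    return out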
@@ -71,13 +71,14 @@ class TestFlashMaskAttention(unittest.TestCase):
         v_input_pad[0 : v_input.shape[0]] = v_input
         mask = paddle.to_tensor(mask).astype("int32")
 
-        out = flash_attention_mask(
+        flash_attention_mask(
            q_input,
            k_input,
            v_input_pad,
            cu_seq_q,
            cu_seq_k,
            seq_len_encoder,
+           attn_out,
            mask,
            int(q_input.shape[1]),
            int(k_input.shape[1]),
@@ -86,7 +87,6 @@ class TestFlashMaskAttention(unittest.TestCase):
            int(q_input.shape[0]),
            int(k_input.shape[0]),
         )
-        return out
 
     def test_flash_attention_mask(self):
         q_input = np.random.normal(0, 0.5, size=(self.bsz, self.q_seq_len, self.num_head, self.head_dim))
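The hunks above (and the one below) move the test from out = flash_attention_mask(...) to passing a pre-allocated attn_out buffer that the op writes into. A minimal sketch of that allocate-then-call pattern, using a hypothetical stand-in op since the real flash_attention_mask signature is only what the test shows:

import paddle

def attention_into_(q, k, v, attn_out):
    # Hypothetical stand-in: compute plain attention and write the result into
    # the caller-provided buffer instead of returning a new tensor.
    scores = paddle.matmul(q, k, transpose_y=True) / (q.shape[-1] ** 0.5)
    probs = paddle.nn.functional.softmax(scores, axis=-1)
    paddle.assign(paddle.matmul(probs, v), output=attn_out)

q = paddle.randn([2, 8, 64])
k = paddle.randn([2, 8, 64])
v = paddle.randn([2, 8, 64])

attn_out = paddle.zeros(q.shape, dtype=q.dtype)  # caller pre-allocates the output
attention_into_(q, k, v, attn_out)               # op fills it; nothing is returned
print(float(attn_out.abs().max()))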
@@ -105,7 +105,8 @@ class TestFlashMaskAttention(unittest.TestCase):
         mask[text_len : text_len + image_len] = text_len + image_len + self.k_seq_len
 
         naive_attn_out = self.naive_attn(q_input, k_input, v_input, mask)
-        paddle_attn_out = self.paddle_flash_attn_mask(q_input, k_input, v_input, mask)
+        paddle_attn_out = paddle.zeros(naive_attn_out.shape, dtype="bfloat16")
+        self.paddle_flash_attn_mask(q_input, k_input, v_input, paddle_attn_out, mask)
 
         max_diff = float((paddle_attn_out.reshape([-1]) - paddle.to_tensor(naive_attn_out).reshape([-1])).max())
         self.assertLessEqual(max_diff, 0.05)