From 8e1b35a09b0d0d1acaa58c6c8ece03a05c4c0207 Mon Sep 17 00:00:00 2001
From: yangjianfengo1 <125249383+yangjianfengo1@users.noreply.github.com>
Date: Tue, 2 Sep 2025 19:17:01 +0800
Subject: [PATCH] [Fix bug] Pin w4afp8's kBlockN to 256, and pass the mask
 parameter to fa3's append attn (#3771)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* fix w4afp8

* add centralized configuration

* codestyle

* fix fa3 append attn

---
 custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu              | 6 +-----
 custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py            | 2 ++
 .../model_executor/layers/attention/flash_attn_backend.py | 1 +
 3 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu
index ed1071538..53685c5c9 100644
--- a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu
+++ b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu
@@ -75,12 +75,8 @@ void DisPatchW4AFp8Gemm(
     const int64_t K,
     cudaStream_t stream) {

-  int kBlockN = (max_tokens + 15) / 16 * 16;
+  int kBlockN = 256;
   int TailN = 0;
-  if (kBlockN > 256) {
-    TailN = kBlockN % 256;
-    kBlockN = 256;
-  }
   if constexpr (std::is_same_v) {
     GEMM_SWITCH_BF16(
         M, K, batch_size, token_padding_size, kBlockN, TailN,
diff --git a/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py b/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py
index 1529fa8de..1acf3c80a 100644
--- a/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py
+++ b/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py
@@ -88,6 +88,8 @@ gemm_case = [
     [8192, 3584, 8, 2048],  # eb45T ffn1
     [7168, 8192, 8, 0],  # eb45T ffn2
     [7168, 8192, 8, 2048],  # eb45T ffn2
+    [1792, 8192, 64, 0],  # eb45t ffn1
+    [8192, 896, 64, 0],  # eb45t ffn2
 ]

 dtype = ["BF16"]
diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
index 193a31ff5..8f220ddb9 100644
--- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
@@ -359,6 +359,7 @@ class FlashAttentionBackend(AttentionBackend):
                 getattr(layer, "cache_k_zp", None),
                 getattr(layer, "cache_v_zp", None),
                 layer.linear_shift,
                 layer.linear_smooth,
+                forward_meta.attn_mask_offsets,
                 metadata.kv_signal_data_list[layer.layer_id],
                 getattr(layer, "q_norm_weight", None),
                 getattr(layer, "k_norm_weight", None),
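
Context for the first hunk: the old dispatch rounded max_tokens up to a multiple of 16 and split anything above 256 into a tail block, so kBlockN could come out as 16, 48, and so on, while the pre-generated kernel set presumably only covers the fixed 256-wide N block. Pinning kBlockN to 256 with TailN = 0 keeps dispatch inside the generated set. A minimal Python sketch of the two computations (old_dispatch/new_dispatch are illustrative names, not FastDeploy functions):

    def old_dispatch(max_tokens: int) -> tuple[int, int]:
        """Pre-fix logic: round up to a multiple of 16, cap at 256,
        and push the overflow into a tail block."""
        k_block_n = (max_tokens + 15) // 16 * 16
        tail_n = 0
        if k_block_n > 256:
            tail_n = k_block_n % 256
            k_block_n = 256
        return k_block_n, tail_n

    def new_dispatch(max_tokens: int) -> tuple[int, int]:
        """Post-fix logic: kBlockN is pinned to 256, TailN is always 0."""
        return 256, 0

    for tokens in (7, 300, 513):
        print(tokens, old_dispatch(tokens), new_dispatch(tokens))
    # 7   -> old (16, 0)   new (256, 0)
    # 300 -> old (256, 48) new (256, 0)
    # 513 -> old (256, 16) new (256, 0)

The old (16, 0) and (256, 48) results are exactly the block shapes at risk of having no generated kernel behind them, which is what the hunk removes.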
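
Context for the second hunk: each gemm_case row feeds the kernel generator, and the two new rows cover a 64-expert model's ffn1/ffn2 shapes. Judging by the GEMM_SWITCH_BF16 call in the first hunk, the four numbers likely map to a weight dimension, K, the batch/expert count, and token_padding_size, though that column meaning is an inference. A rough sketch of how such a table typically drives instantiation (the loop and naming are assumptions, not the actual generator):

    # Assumed column meaning, mirroring GEMM_SWITCH_BF16's argument order
    # (an inference, not confirmed by the patch):
    gemm_case = [
        [1792, 8192, 64, 0],  # eb45t ffn1
        [8192, 896, 64, 0],   # eb45t ffn2
    ]

    K_BLOCK_N = 256  # matches the kBlockN now pinned in DisPatchW4AFp8Gemm

    for n, k, batch, padding in gemm_case:
        # The real script writes out .cu sources; printing stands in for that.
        print(
            f"instantiate w4afp8 gemm: N={n} K={k} batch={batch} "
            f"padding={padding} kBlockN={K_BLOCK_N} TailN=0"
        )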
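
Context for the third hunk: the append-attention entry point is invoked with a long positional argument list, and it gained a mask parameter, so the FA3 backend must now pass forward_meta.attn_mask_offsets in the matching slot. The toy below (invented names, not FastDeploy's real API) shows the failure mode such a fix addresses: once the callee grows a parameter, an unchanged call site can still run while silently binding later arguments to the wrong parameters.

    def fused_attention(linear_shift, linear_smooth, attn_mask_offsets,
                        kv_signal_data, q_norm_weight=None):
        """Toy stand-in for a long positional op signature."""
        return {"mask": attn_mask_offsets, "signal": kv_signal_data,
                "q_norm": q_norm_weight}

    # A call site written against the old, mask-less signature still runs,
    # but every later argument lands one slot too early:
    stale = fused_attention("shift", "smooth", "signal-tensor", "q-norm-tensor")
    assert stale["mask"] == "signal-tensor"    # signal bound as the mask
    assert stale["signal"] == "q-norm-tensor"  # norm weight bound as the signal

    # The fix threads the new argument through at the matching position:
    fixed = fused_attention("shift", "smooth", None,
                            "signal-tensor", "q-norm-tensor")
    assert fixed["mask"] is None and fixed["signal"] == "signal-tensor"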