diff --git a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu
index ed1071538..53685c5c9 100644
--- a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu
+++ b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu
@@ -75,12 +75,8 @@ void DisPatchW4AFp8Gemm(
     const int64_t K,
     cudaStream_t stream) {
-  int kBlockN = (max_tokens + 15) / 16 * 16;
+  int kBlockN = 256;
   int TailN = 0;
-  if (kBlockN > 256) {
-    TailN = kBlockN % 256;
-    kBlockN = 256;
-  }
   if constexpr (std::is_same_v) {
     GEMM_SWITCH_BF16(
         M, K, batch_size, token_padding_size, kBlockN, TailN,
diff --git a/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py b/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py
index 1529fa8de..1acf3c80a 100644
--- a/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py
+++ b/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py
@@ -88,6 +88,8 @@ gemm_case = [
     [8192, 3584, 8, 2048],  # eb45T ffn1
     [7168, 8192, 8, 0],  # eb45T ffn2
     [7168, 8192, 8, 2048],  # eb45T ffn2
+    [1792, 8192, 64, 0],  # eb45t ffn1
+    [8192, 896, 64, 0],  # eb45t ffn2
 ]
 
 dtype = ["BF16"]
diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
index 193a31ff5..8f220ddb9 100644
--- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
+++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py
@@ -359,6 +359,7 @@ class FlashAttentionBackend(AttentionBackend):
                 getattr(layer, "cache_v_zp", None),
                 layer.linear_shift,
                 layer.linear_smooth,
+                forward_meta.attn_mask_offsets,
                 metadata.kv_signal_data_list[layer.layer_id],
                 getattr(layer, "q_norm_weight", None),
                 getattr(layer, "k_norm_weight", None),
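
For context on the first hunk: kBlockN was previously derived from max_tokens (rounded up to a multiple of 16), and any excess over 256 was carried as a TailN remainder, whereas the new code pins kBlockN to 256 with no tail. Below is a minimal standalone sketch of that before/after arithmetic, not taken from the PR; max_tokens = 300 is an arbitrary illustrative value.

#include <cstdio>

int main() {
  const int max_tokens = 300;  // illustrative value only, not from the PR

  // Old selection: round max_tokens up to a multiple of 16, cap the block
  // width at 256, and carry the remainder as a separate tail block size.
  int old_kBlockN = (max_tokens + 15) / 16 * 16;  // 304
  int old_TailN = 0;
  if (old_kBlockN > 256) {
    old_TailN = old_kBlockN % 256;  // 48
    old_kBlockN = 256;
  }

  // New selection: a fixed 256-wide block and no tail, so every token
  // count dispatches onto the same pre-generated kernel shapes.
  const int new_kBlockN = 256;
  const int new_TailN = 0;

  std::printf("old: kBlockN=%d TailN=%d\n", old_kBlockN, old_TailN);  // 256 / 48
  std::printf("new: kBlockN=%d TailN=%d\n", new_kBlockN, new_TailN);  // 256 / 0
  return 0;
}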