Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-10-05 00:33:03 +08:00)
[Fix bug] Fix the w4afp8 nblock to 256, and add a mask parameter to fa3 append attn (#3771)
* fix w4afp8
* add centralized configuration
* codestyle
* fix fa3 append attn
@@ -75,12 +75,8 @@ void DisPatchW4AFp8Gemm(
     const int64_t K,
     cudaStream_t stream) {

-  int kBlockN = (max_tokens + 15) / 16 * 16;
+  int kBlockN = 256;
   int TailN = 0;
-  if (kBlockN > 256) {
-    TailN = kBlockN % 256;
-    kBlockN = 256;
-  }
   if constexpr (std::is_same_v<OutputType, cutlass::bfloat16_t>) {
     GEMM_SWITCH_BF16(
         M, K, batch_size, token_padding_size, kBlockN, TailN,
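The hunk above drops the dynamic block-size computation: previously kBlockN was max_tokens rounded up to a multiple of 16, and only clamped to 256 (with the remainder carried in TailN) when it overflowed; now kBlockN is pinned to 256 and TailN stays 0. A minimal Python sketch of the two policies for comparison (illustrative only, not the FastDeploy implementation; max_tokens is the token count the dispatcher sees):

def old_block_policy(max_tokens: int) -> tuple[int, int]:
    # Pre-fix: round up to a multiple of 16, clamp to 256 and keep the remainder as a tail.
    k_block_n = (max_tokens + 15) // 16 * 16
    tail_n = 0
    if k_block_n > 256:
        tail_n = k_block_n % 256
        k_block_n = 256
    return k_block_n, tail_n

def new_block_policy(max_tokens: int) -> tuple[int, int]:
    # Post-fix: kBlockN is always 256 and no tail block is emitted.
    return 256, 0

for tokens in (48, 256, 300, 2048):
    print(tokens, old_block_policy(tokens), new_block_policy(tokens))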
@@ -88,6 +88,8 @@ gemm_case = [
     [8192, 3584, 8, 2048],  # eb45T ffn1
     [7168, 8192, 8, 0],  # eb45T ffn2
     [7168, 8192, 8, 2048],  # eb45T ffn2
+    [1792, 8192, 64, 0],  # eb45t ffn1
+    [8192, 896, 64, 0],  # eb45t ffn2
 ]

 dtype = ["BF16"]
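These two new entries extend the centralized GEMM case table referenced in the commit message ("add centralized configuration"). As a reading aid only, and assuming the fields are [N, K, num_experts, token_padding_size] based on the surrounding entries, here is a sketch of how such a table could be expanded per dtype into kernel instantiations (hypothetical helper, not FastDeploy's actual generator):

gemm_case = [
    [8192, 3584, 8, 2048],  # eb45T ffn1
    [7168, 8192, 8, 0],  # eb45T ffn2
    [7168, 8192, 8, 2048],  # eb45T ffn2
    [1792, 8192, 64, 0],  # eb45t ffn1 (added in this commit)
    [8192, 896, 64, 0],  # eb45t ffn2 (added in this commit)
]
dtype = ["BF16"]

def enumerate_kernels(cases, dtypes):
    # Yield one (dtype, N, K, num_experts, token_padding_size) tuple per case;
    # the field names here are assumptions, not the script's actual semantics.
    for dt in dtypes:
        for n, k, num_experts, token_padding in cases:
            yield dt, n, k, num_experts, token_padding

for spec in enumerate_kernels(gemm_case, dtype):
    print("instantiate w4afp8 gemm for", spec)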
@@ -359,6 +359,7 @@ class FlashAttentionBackend(AttentionBackend):
             getattr(layer, "cache_v_zp", None),
             layer.linear_shift,
             layer.linear_smooth,
+            forward_meta.attn_mask_offsets,
             metadata.kv_signal_data_list[layer.layer_id],
             getattr(layer, "q_norm_weight", None),
             getattr(layer, "k_norm_weight", None),
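The only change in this file is that forward_meta.attn_mask_offsets is now threaded through to the append-attention call, which is the "mask parameter for fa3 append attn" from the title. A minimal sketch of that plumbing with hypothetical names (the real ForwardMeta and kernel signature live elsewhere in FastDeploy; None keeps the previous no-mask behavior):

from dataclasses import dataclass
from typing import Optional

import paddle

@dataclass
class ForwardMeta:
    # Hypothetical container; the real ForwardMeta carries many more fields.
    attn_mask_offsets: Optional[paddle.Tensor] = None

def append_attention_stub(*, attn_mask_offsets=None, **unused):
    # Stand-in for the real op; only reports whether mask offsets were supplied.
    return attn_mask_offsets is not None

meta = ForwardMeta(attn_mask_offsets=paddle.zeros([4], dtype="int32"))
print("mask passed:", append_attention_stub(attn_mask_offsets=meta.attn_mask_offsets))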