【Fix Bug】 修复 fa3 支持集中式bug (#3235)
Some checks failed
Deploy GitHub Pages / deploy (push) Has been cancelled

* fix fa3 集中式bug

* 增加qknorm参数
This commit is contained in:
yangjianfengo1
2025-08-06 16:24:27 +08:00
committed by GitHub
parent afff4d37ea
commit 3a15e0c53e

View File

@@ -344,7 +344,7 @@ class FlashAttentionBackend(AttentionBackend):
forward_meta.decoder_batch_ids, # from buffer forward_meta.decoder_batch_ids, # from buffer
forward_meta.decoder_tile_ids_per_batch, # from buffer forward_meta.decoder_tile_ids_per_batch, # from buffer
forward_meta.decoder_num_blocks_cpu, forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu, metadata.max_len_tensor_cpu_decoder,
metadata.max_len_kv, metadata.max_len_kv,
metadata.rotary_embs, metadata.rotary_embs,
forward_meta.attn_mask, forward_meta.attn_mask,
@@ -359,6 +359,9 @@ class FlashAttentionBackend(AttentionBackend):
layer.linear_shift, layer.linear_shift,
layer.linear_smooth, layer.linear_smooth,
metadata.kv_signal_data_list[layer.layer_id], metadata.kv_signal_data_list[layer.layer_id],
getattr(layer, "q_norm_weight", None),
getattr(layer, "k_norm_weight", None),
getattr(layer, "rms_norm_eps", 1e-6),
metadata._fuse_kernel_compute_dtype, metadata._fuse_kernel_compute_dtype,
getattr(layer, "cache_quant_type_str", "none"), getattr(layer, "cache_quant_type_str", "none"),
layer.use_neox_rotary_style, layer.use_neox_rotary_style,