[Feature] support flash_mask_attention backend (#5134)

* [Feature] support flash_mask_attention backend

* fix unittest

* clean code
This commit is contained in:
lizhenyun01
2025-11-28 10:12:16 +08:00
committed by GitHub
parent b935101008
commit aba4fc657f
13 changed files with 542 additions and 69 deletions


@@ -178,6 +178,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
     const paddle::Tensor& cache_batch_ids,
     const paddle::Tensor& cache_tile_ids,
     const paddle::Tensor& cache_num_blocks,
+    const paddle::optional<paddle::Tensor>& q_norm_weight,
+    const paddle::optional<paddle::Tensor>& k_norm_weight,
     const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
     const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
     const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
@@ -187,6 +189,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
     const paddle::optional<paddle::Tensor>& kv_signal_data,
     const int kv_token_num,
     const int max_seq_len,
+    const float rms_norm_eps,
     const std::string& cache_quant_type,
     const bool rope_3d);
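The new q_norm_weight / k_norm_weight optionals together with rms_norm_eps suggest an optional per-head QK RMSNorm applied before RoPE and the cache write. The standalone sketch below only illustrates that gating pattern; the helper names (rms_norm_head, maybe_qk_norm) and the use of std::optional in place of paddle::optional<paddle::Tensor> are illustrative assumptions, not code from this PR.

```cpp
// Minimal sketch (not the fused kernel): per-head RMSNorm gated by optional
// weights, mirroring how the new optional parameters could be consumed.
#include <cmath>
#include <cstddef>
#include <optional>
#include <vector>

// RMSNorm over one head_dim-sized vector, in place:
// x[i] = x[i] / sqrt(mean(x^2) + eps) * weight[i]
void rms_norm_head(std::vector<float>& x,
                   const std::vector<float>& weight,
                   float eps) {
  float sum_sq = 0.f;
  for (float v : x) sum_sq += v * v;
  const float inv_rms = 1.f / std::sqrt(sum_sq / x.size() + eps);
  for (std::size_t i = 0; i < x.size(); ++i) x[i] = x[i] * inv_rms * weight[i];
}

// The norm is applied only when both weights are provided, analogous to
// checking the paddle::optional arguments inside the kernel.
void maybe_qk_norm(std::vector<float>& q, std::vector<float>& k,
                   const std::optional<std::vector<float>>& q_norm_weight,
                   const std::optional<std::vector<float>>& k_norm_weight,
                   float rms_norm_eps) {
  if (q_norm_weight && k_norm_weight) {
    rms_norm_head(q, *q_norm_weight, rms_norm_eps);
    rms_norm_head(k, *k_norm_weight, rms_norm_eps);
  }
}

int main() {
  std::vector<float> q{1.f, 2.f, 3.f, 4.f}, k{4.f, 3.f, 2.f, 1.f};
  std::vector<float> w(4, 1.0f);
  maybe_qk_norm(q, k, w, w, 1e-6f);                         // norm applied
  maybe_qk_norm(q, k, std::nullopt, std::nullopt, 1e-6f);   // no-op
  return 0;
}
```

In this sketch, omitting both weights skips the norm entirely, which is how optional parameters of this kind typically preserve the previous behavior of existing callers.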