Mirror of https://github.com/PaddlePaddle/FastDeploy.git (synced 2025-12-24 13:28:13 +08:00)

* [Feature] support flash_mask_attention backend * fix unittest * clean code
@@ -24,6 +24,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
    const T *qkv,
    const float *cos_emb,
    const float *sin_emb,
    const float *q_norm_weight,
    const float *k_norm_weight,
    const int *batch_id_per_token,
    const int *cu_seqlens_q,
    const int *seq_lens,
@@ -38,37 +40,46 @@ __global__ void GQAVariableLengthRotarySplitKernel(
    const int kv_num_head,
    const int seq_len,
    const int last_dim,
    const bool rope_3d) {
    const bool rope_3d,
    const float rms_norm_eps) {
  using LoadT = AlignedVector<T, VecSize>;
  constexpr int HalfVecSize = VecSize / 2;
  using LoadEmbT = AlignedVector<float, HalfVecSize>;
  using LoadFloat = AlignedVector<float, VecSize>;
  LoadT src_vec;
  LoadEmbT cos_emb_vec;
  LoadEmbT sin_emb_vec;
  int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
  LoadFloat tmp_vec;
  LoadFloat q_norm_vec, k_norm_vec;
  int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
  int64_t all_warp_num = gridDim.x * blockDim.y;
  const int half_lastdim = last_dim / 2;
  const int offset = (q_num_head + kv_num_head * 2) * last_dim;
  for (int64_t linear_index = global_thread_idx * VecSize,
               step = gridDim.x * blockDim.x * VecSize;
       linear_index < elem_cnt;
       linear_index += step) {
    const int token_idx = linear_index / offset;
    const int ori_bi = batch_id_per_token[token_idx];
  const int offset =
      (q_num_head + kv_num_head * 2) * last_dim;  // for all q, k, v
  const int all_head_num = elem_cnt / last_dim;
  for (int global_hi = global_warp_idx; global_hi < all_head_num;
       global_hi += all_warp_num) {
    int64_t linear_index =
        global_hi * last_dim + threadIdx.x * VecSize;  // global index
    const int token_idx =
        linear_index / offset;  // token index (shared across q, k and v)
    const int ori_bi = batch_id_per_token[token_idx];  // batch index
    if (seq_lens[ori_bi] == 0) continue;
    const int bias = linear_index % offset;
    const int hi = bias / last_dim;
    const int h_bias = bias % last_dim;

    const int ori_seq_id =
        (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];
    const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;

    const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
    int64_t new_emb_idx =
        rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
    (token_idx - cu_seqlens_q[ori_bi]) +
        seq_lens_decoder[ori_bi];  // index within the current sequence
                                   // (valid when sequences are packed
                                   // into one batch)
    const int64_t emb_idx =
        ori_seq_id * half_lastdim + h_bias / 2;  // embedding index
    const int64_t base_idx =
        token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
        h_bias;
    Load<T, VecSize>(&qkv[base_idx], &src_vec);
    const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;
    int64_t base_split_idx;
    T *out_p = nullptr;
    if (hi < q_num_head) {
@@ -84,21 +95,67 @@ __global__ void GQAVariableLengthRotarySplitKernel(
      base_split_idx = kv_write_idx * kv_num_head * last_dim +
                       (hi - q_num_head - kv_num_head) * last_dim + h_bias;
    }
    Load<T, VecSize>(&qkv[base_idx], &src_vec);
    // do rope
    if (hi < q_num_head + kv_num_head) {
      Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
      Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);

    // TODO: check whether this is correct
    int64_t new_emb_idx =
        rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
    float thread_m2 = 0.0f;
    float warp_m2 = 0.0f;
    if (q_norm_weight && k_norm_weight) {
      if (hi < q_num_head + kv_num_head) {  // only q and k need rope
        Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
        Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
        for (int i = 0; i < HalfVecSize; i++) {
          const float input_left = static_cast<float>(src_vec[2 * i]);
          const float input_right = static_cast<float>(src_vec[2 * i + 1]);
          const float cos_tmp = cos_emb_vec[i];
          const float sin_tmp = sin_emb_vec[i];
          src_vec[2 * i] =
              static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
          src_vec[2 * i + 1] =
              static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
        for (int i = 0; i < HalfVecSize; i++) {
          const float input_left = static_cast<float>(src_vec[2 * i]);
          const float input_right = static_cast<float>(src_vec[2 * i + 1]);
          const float cos_tmp = cos_emb_vec[i];
          const float sin_tmp = sin_emb_vec[i];
          float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
          float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
          tmp_vec[2 * i] = tmp1;
          tmp_vec[2 * i + 1] = tmp2;
          thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
        }
      }
      WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);  // per-head sum of squares (for the variance)

      if (hi < q_num_head + kv_num_head) {  // only q and k need norm
        float row_variance = max(warp_m2 / last_dim, 0.0f);
        float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
        if (hi < q_num_head) {
          Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize],
                               &q_norm_vec);
#pragma unroll
          for (int i = 0; i < VecSize; i++) {
            src_vec[i] =
                static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
          }
        } else {
          Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize],
                               &k_norm_vec);
          for (int i = 0; i < VecSize; i++) {
            src_vec[i] =
                static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
          }
        }
      }
    } else {
      if (hi < q_num_head + kv_num_head) {
        Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
        Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
        for (int i = 0; i < HalfVecSize; i++) {
          const float input_left = static_cast<float>(src_vec[2 * i]);
          const float input_right = static_cast<float>(src_vec[2 * i + 1]);
          const float cos_tmp = cos_emb_vec[i];
          const float sin_tmp = sin_emb_vec[i];
          src_vec[2 * i] =
              static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
          src_vec[2 * i + 1] =
              static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
        }
      }
    }
    Store<T, VecSize>(src_vec, &qkv_out[base_idx]);
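For intuition, the fused path above applies an interleaved-pair rotary embedding and then RMSNorm over each q/k head. Below is a minimal NumPy sketch of that per-head math (illustrative only; `rope_rmsnorm_head` and the frequency schedule are my own names/assumptions, not FastDeploy code):

```python
import numpy as np

def rope_rmsnorm_head(x, cos, sin, norm_w, eps=1e-6):
    """Mirror of the kernel's per-head math: rotate interleaved pairs,
    then RMS-normalize with a learned weight."""
    left, right = x[0::2], x[1::2]        # interleaved (even, odd) pairs
    rot = np.empty_like(x)
    rot[0::2] = left * cos - right * sin  # same rotation as the CUDA loop
    rot[1::2] = right * cos + left * sin
    m2 = np.sum(rot * rot)                # per-head sum of squares (thread_m2 / warp_m2)
    inv_var = 1.0 / np.sqrt(m2 / x.shape[0] + eps)  # row_inv_var
    return rot * inv_var * norm_w

head_dim = 128
rng = np.random.default_rng(0)
x = rng.standard_normal(head_dim).astype(np.float32)
theta = 10000.0 ** (-np.arange(head_dim // 2) / (head_dim // 2))  # assumed schedule
pos = 7                                   # token position within its sequence
out = rope_rmsnorm_head(x, np.cos(pos * theta).astype(np.float32),
                        np.sin(pos * theta).astype(np.float32),
                        np.ones(head_dim, dtype=np.float32))
```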
@@ -114,6 +171,8 @@ void gqa_rotary_qk_split_variable(
    T *v,
    const T *qkv_input,
    const float *rotary_emb,  // [2, 1, 1, seq_len, dim_head / 2]
    const float *q_norm_weight,
    const float *k_norm_weight,
    const int *batch_id_per_token,
    const int *seq_lens_encoder,
    const int *seq_lens_decoder,
@@ -126,24 +185,31 @@ void gqa_rotary_qk_split_variable(
    const int input_output_len,
    const int dim_head,
    const bool rope_3d,
    const float rms_norm_eps,
    const cudaStream_t &stream) {
  assert(dim_head == 128 && "dim_head must be 128");
  int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * dim_head;
  constexpr int PackSize = 16 / sizeof(T);

  constexpr int HEAD_DIM = 128;
  constexpr int PackSize = HEAD_DIM / kWarpSize;
  const int pack_num = elem_nums / PackSize;
  const int blocksize = 128;
  int grid_size = 1;
  GetNumBlocks<128>(pack_num, &grid_size);
  dim3 block_size(kWarpSize, blocksize / kWarpSize);

  const float *cos_emb = rotary_emb;
  const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;
  launchWithPdlWhenEnabled(GQAVariableLengthRotarySplitKernel<T, PackSize>,
                           grid_size,
                           blocksize,
                           block_size,
                           0,
                           stream,
                           qkv_input,
                           cos_emb,
                           sin_emb,
                           q_norm_weight,
                           k_norm_weight,
                           batch_id_per_token,
                           cu_seqlens_q,
                           seq_lens_encoder,
@@ -158,7 +224,8 @@ void gqa_rotary_qk_split_variable(
                           kv_num_heads,
                           seq_len,
                           dim_head,
                           rope_3d);
                           rope_3d,
                           rms_norm_eps);
}
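A quick arithmetic check on the new launch geometry (a sketch; `kWarpSize = 32` is assumed, as in standard CUDA): one warp now owns one 128-wide head row, each lane handling HEAD_DIM / 32 = 4 elements, with four such warps per 128-thread block.

```python
# Illustrative launch-geometry arithmetic for the rewritten launcher.
kWarpSize = 32                                   # assumed CUDA warp size
HEAD_DIM = 128
blocksize = 128                                  # threads per block

PackSize = HEAD_DIM // kWarpSize                 # 4 elements per lane
block_dim = (kWarpSize, blocksize // kWarpSize)  # (32, 4): lanes x warps per block
assert kWarpSize * PackSize == HEAD_DIM          # one warp covers a full head row
print(PackSize, block_dim)                       # -> 4 (32, 4)
```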
template <typename T,
@@ -1054,6 +1121,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
    const paddle::Tensor &cache_batch_ids,
    const paddle::Tensor &cache_tile_ids,
    const paddle::Tensor &cache_num_blocks,
    const paddle::optional<paddle::Tensor> &q_norm_weight,
    const paddle::optional<paddle::Tensor> &k_norm_weight,
    const paddle::optional<paddle::Tensor> &cache_k_quant_scales,
    const paddle::optional<paddle::Tensor> &cache_v_quant_scales,
    const paddle::optional<paddle::Tensor> &cache_k_dequant_scales,
@@ -1063,6 +1132,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
    const paddle::optional<paddle::Tensor> &kv_signal_data,
    const int kv_token_num,
    const int max_seq_len,
    const float rms_norm_eps,
    const std::string &cache_quant_type,
    const bool rope_3d) {
  typedef PDTraits<paddle::DataType::BFLOAT16> traits_;
@@ -1113,6 +1183,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
      v.data<data_t>(),
      qkv.data<data_t>(),
      rotary_embs.data<float>(),
      q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
      k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
      batch_id_per_token.data<int>(),
      seq_lens_encoder.data<int>(),
      seq_lens_decoder.data<int>(),
@@ -1125,6 +1197,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
      rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
      head_dim,
      rope_3d,
      rms_norm_eps,
      stream);

  if (token_num < kv_token_num) {
@@ -1259,6 +1332,8 @@ PD_BUILD_STATIC_OP(gqa_rope_write_cache)
             "cache_batch_ids",
             "cache_tile_ids_per_batch",
             "cache_num_blocks",
             paddle::Optional("q_norm_weight"),
             paddle::Optional("k_norm_weight"),
             paddle::Optional("cache_k_quant_scales"),
             paddle::Optional("cache_v_quant_scales"),
             paddle::Optional("cache_k_dequant_scales"),
@@ -1271,5 +1346,7 @@ PD_BUILD_STATIC_OP(gqa_rope_write_cache)
                    {"value_cache", "value_cache_out"}})
    .Attrs({"kv_token_num: int",
            "max_seq_len: int",
            "cache_quant_type: std::string"})
            "rms_norm_eps: float",
            "cache_quant_type: std::string",
            "rope_3d: bool"})
    .SetKernelFn(PD_KERNEL(GQARopeWriteCacheKernel));
@@ -178,6 +178,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
    const paddle::Tensor& cache_batch_ids,
    const paddle::Tensor& cache_tile_ids,
    const paddle::Tensor& cache_num_blocks,
    const paddle::optional<paddle::Tensor>& q_norm_weight,
    const paddle::optional<paddle::Tensor>& k_norm_weight,
    const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
    const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
    const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
@@ -187,6 +189,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
    const paddle::optional<paddle::Tensor>& kv_signal_data,
    const int kv_token_num,
    const int max_seq_len,
    const float rms_norm_eps,
    const std::string& cache_quant_type,
    const bool rope_3d);
@@ -46,12 +46,11 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input,
    const int kv_head_num,
    const int head_dim,
    const int max_seq_len,
    const int max_enc_len_this_time,
    const int max_dec_len_this_time) {
    const int q_token_num,
    const int k_token_num) {
  constexpr int kBlockM = 128;
  constexpr int kBlockN = 128;
  const int batch_size = seq_len_encoder.dims()[0];

  const int batch_size = cu_seq_k.dims()[0] - 1;
  Flash_mask_params params;
  memset(&params, 0, sizeof(Flash_mask_params));

@@ -63,8 +62,8 @@ void DispatchFlashAttentionMask(const paddle::Tensor& q_input,
  params.seq_len_encoder = const_cast<int*>(seq_len_encoder.data<int>());
  params.head_num = head_num;
  params.kv_head_num = kv_head_num;
  params.max_seq_len_q = max_enc_len_this_time;
  params.max_seq_len_k = max_enc_len_this_time + max_dec_len_this_time;
  params.q_token_num = q_token_num;
  params.k_token_num = k_token_num;
  params.batch_size = batch_size;
  params.gqa_group_size = head_num / kv_head_num;
  constexpr float kLog2e = 1.4426950408889634074;
@@ -132,8 +131,8 @@ void FlashAttentionMask(const paddle::Tensor& q_input,
    const int kv_head_num,
    const int head_dim,
    const int max_seq_len,
    const int max_enc_len_this_time,
    const int max_dec_len_this_time) {
    const int q_token_num,
    const int k_token_num) {
  if (q_input.dtype() == paddle::DataType::FLOAT16) {
    using T = phi::dtype::float16;
    DispatchFlashAttentionMask<T>(q_input,
@@ -148,8 +147,8 @@ void FlashAttentionMask(const paddle::Tensor& q_input,
                                  kv_head_num,
                                  head_dim,
                                  max_seq_len,
                                  max_enc_len_this_time,
                                  max_dec_len_this_time);
                                  q_token_num,
                                  k_token_num);
  } else if (q_input.dtype() == paddle::DataType::BFLOAT16) {
    using T = phi::dtype::bfloat16;
    DispatchFlashAttentionMask<T>(q_input,
@@ -164,12 +163,12 @@ void FlashAttentionMask(const paddle::Tensor& q_input,
                                  kv_head_num,
                                  head_dim,
                                  max_seq_len,
                                  max_enc_len_this_time,
                                  max_dec_len_this_time);
                                  q_token_num,
                                  k_token_num);
  }
}
PD_BUILD_STATIC_OP(flash_attention_mask)
PD_BUILD_STATIC_OP(flash_mask_attention)
    .Inputs({"q_input",
             "k_input",
             "v_input",
@@ -182,8 +181,8 @@ PD_BUILD_STATIC_OP(flash_attention_mask)
            "kv_head_num: int",
            "head_dim: int",
            "max_seq_len: int",
            "max_enc_len_this_time: int",
            "max_dec_len_this_time: int"})
            "q_token_num: int",
            "k_token_num: int"})
    .Outputs({"out"})
    .SetInplaceMap({{"attn_out", "out"}})
    .SetKernelFn(PD_KERNEL(FlashAttentionMask));
@@ -59,19 +59,26 @@ __global__ void __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
  auto &shared_storage =
      *reinterpret_cast<typename Ktraits::SharedStorage *>(shared_memory);

  __align__(16) __shared__ int mask[kBlockM];
  __align__(16) __shared__ int mask_end[kBlockM];
  __align__(16) __shared__ int mask_start[kBlockM];

  const int m_block = blockIdx.x;
  const int bidh = blockIdx.y;
  const int bidb = blockIdx.z;
  if (data_params.seq_len_encoder[bidb] <= 0) {
    return;
  }

  if constexpr (NeedMask) {
    const int *mask_this_batch =
        data_params.mask + data_params.cu_seq_q[bidb] + m_block * kBlockM;
    const int2 *mask_this_batch =
        reinterpret_cast<int2 *>(data_params.mask) +
        (data_params.cu_seq_q[bidb] + m_block * kBlockM);

    for (int i = threadIdx.x; i < kBlockM;
         i += Ktraits::kNWarps * cutlass::NumThreadsPerWarp) {
      mask[i] = mask_this_batch[i];
      int2 mask_value = mask_this_batch[i];
      mask_start[i] = mask_value.x;
      mask_end[i] = mask_value.y;
    }
  }
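The mask layout changes here from one int per query (an exclusive end offset, formerly `mask[i]`) to an `int2` per query carrying a `[start, end)` window (`mask_start[i]`, `mask_end[i]`). A small NumPy sketch of the implied semantics (my reading of the diff, not repository code):

```python
import numpy as np

def expand_ranges(mask_start, mask_end, k_len):
    """Expand per-query [start, end) key windows into a boolean attention mask."""
    q_len = len(mask_start)
    allow = np.zeros((q_len, k_len), dtype=bool)
    for i in range(q_len):
        allow[i, mask_start[i]:mask_end[i]] = True
    return allow

# Example: causal text rows plus a bidirectional image span (rows 2..4),
# mirroring the shape of the mask built in the unit test further below.
mask_end = np.array([1, 2, 5, 5, 5, 6, 7, 8])
mask_start = np.zeros_like(mask_end)   # no lower bound in this example
print(expand_ranges(mask_start, mask_end, 8).astype(int))
```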
@@ -119,7 +126,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,

  const int n_block_max =
      NeedMask
          ? cute::ceil_div(mask[min(kBlockM - 1, real_seq - 1)], kBlockN)
          ? cute::ceil_div(mask_end[min(kBlockM - 1, real_seq - 1)], kBlockN)
          : min(cute::ceil_div((m_block + 1) * kBlockM + seq_len_k - seq_len_q,
                               kBlockN),
                cute::ceil_div(seq_len_k, kBlockN));
@@ -170,7 +177,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps *cutlass::NumThreadsPerWarp,
        smem_pipe_read_v,
        tOrO,
        softmax,
        mask,
        mask_end,
        n_block_max,
        threadIdx.x - NumCopyThreads,
        m_block,
@@ -207,18 +214,15 @@ void run_flash_mask(Flash_mask_params &params, cudaStream_t stream) {
  typename CollectiveMainloop::Params mainloop_params =
      CollectiveMainloop::to_underlying_arguments(
          {static_cast<Element const *>(params.q_ptr),
           get_gmem_layout<kHeadDim>(params.max_seq_len_q * params.batch_size,
                                     params.head_num),
           get_gmem_layout<kHeadDim>(params.q_token_num, params.head_num),
           static_cast<Element const *>(params.k_ptr),
           get_gmem_layout<kHeadDim>(params.max_seq_len_k * params.batch_size,
                                     params.kv_head_num),
           get_gmem_layout<kHeadDim>(params.k_token_num, params.kv_head_num),
           static_cast<Element const *>(params.v_ptr),
           get_gmem_layout<kHeadDim>(params.max_seq_len_k * params.batch_size,
                                     params.kv_head_num),
           get_gmem_layout<kHeadDim>(params.k_token_num, params.kv_head_num),
           params.scale_softmax_log2});

  int num_blocks_m =
      cutlass::ceil_div(params.max_seq_len_q, Kernel_traits::kBlockM);
      cutlass::ceil_div(params.q_token_num, Kernel_traits::kBlockM);

  num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) *
                 size<0>(ClusterShape{});
@@ -35,8 +35,8 @@ struct Flash_mask_params {
  int *seq_len_encoder;
  int head_num;
  int kv_head_num;
  int max_seq_len_q;
  int max_seq_len_k;
  int q_token_num;
  int k_token_num;
  int batch_size;
  int gqa_group_size;
  float scale_softmax_log2;
@@ -18,6 +18,7 @@ from .attention_selecter import get_attention_backend
from .base_attention_backend import AttentionBackend
from .block_multihead_attn_backend import BlockAttentionBackend
from .flash_attn_backend import FlashAttentionBackend
from .flash_mask_attn_backend import FlashMaskAttentionBackend
from .iluvatar_attn_backend import IluvatarAttnBackend
from .mla_attention_backend import MLAAttentionBackend
from .moba_attention_backend import PlasAttentionBackend
@@ -36,4 +37,5 @@ __all__ = [
    "BlockAttentionBackend",
    "Attention",
    "PlasAttentionBackend",
    "FlashMaskAttentionBackend",
]
@@ -0,0 +1,316 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from __future__ import annotations

import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional

import paddle

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
    AttentionBackend,
    AttentionMetadata,
)
from fastdeploy.model_executor.layers.attention.ops import (
    flash_mask_attention,
    get_block_shape_and_split_kv_block,
    gqa_rope_write_cache,
    init_kv_signal_per_query,
    init_signal_layerwise,
    open_shm_and_get_meta_signal,
    pre_cache_len_concat,
)
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
from fastdeploy.platforms import current_platform

if TYPE_CHECKING:
    from fastdeploy.model_executor.forward_meta import ForwardMeta

if current_platform.is_cuda():
    from fastdeploy.model_executor.ops.gpu import merge_prefill_decode_output
else:
    merge_prefill_decode_output = None
@dataclass
class FlashMaskAttentionMetadata(AttentionMetadata):
    """
    FlashMaskAttentionMetadata
    """

    rotary_embs: Optional[paddle.Tensor] = None
    block_tables: Optional[paddle.Tensor] = None

    cu_seqlens_q: paddle.Tensor = None
    cu_seqlens_k: paddle.Tensor = None
    max_seqlen_q: int = 0
    max_seqlen_k: int = 0

    pre_cache_batch_ids = None
    pre_cache_tile_ids_per_batch = None
    pre_cache_num_blocks_cpu = None
    kv_token_num_cpu = None

    # pd_disaggregation
    kv_signal_metadata: Optional[paddle.Tensor] = None
    kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)

    _fuse_kernel_compute_dtype: str = "bf16"
    _dtype: paddle.dtype = paddle.bfloat16

    max_len_tensor_cpu: paddle.Tensor = None
    max_len_tensor_cpu_decoder: paddle.Tensor = None
class FlashMaskAttentionBackend(AttentionBackend):
    """
    FlashMaskAttentionBackend implementation
    """

    __infer_dynamic_dims_fields__ = ["attention_metadata"]
    attention_metadata: FlashMaskAttentionMetadata
    flash_attn_func: callable = None

    def __init__(
        self,
        fd_config: FDConfig,
        kv_num_heads: int,
        num_heads: int,
        head_dim: int,
        encoder_block_shape_q: int = -1,
        decoder_block_shape_q: int = -1,
    ):
        """
        FlashMaskAttentionBackend __init__
        """
        super().__init__()
        self.attention_metadata: FlashMaskAttentionMetadata = None
        self.max_seq_len = fd_config.model_config.max_model_len
        self.causal = getattr(fd_config.model_config, "causal", True)

        self.kv_num_heads = kv_num_heads
        self.num_heads = num_heads
        self.group_size: int = self.num_heads // self.kv_num_heads
        self.head_dim = fd_config.model_config.head_dim
        self.attn_outputsize_tp = self.num_heads * self.head_dim
        self.block_size = fd_config.cache_config.block_size
        self.num_layers: int = fd_config.model_config.num_hidden_layers
        self.encoder_block_shape_q: int = encoder_block_shape_q
        self.decoder_block_shape_q: int = decoder_block_shape_q

        self.speculative_method = fd_config.speculative_config.method
        self.use_speculate = self.speculative_method is not None
        self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
        self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
        self.num_layers_draft_model: int = int(fd_config.speculative_config.method in ["mtp"])

        self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode

        self.start_layer_index: int = fd_config.model_config.start_layer_index

        if fd_config.parallel_config.expert_parallel_rank is None:
            fd_config.parallel_config.expert_parallel_rank = 0

        self.rank, self.device_id = init_rank_and_device_id(fd_config)

        self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False) or getattr(
            fd_config.model_config, "use_3d_rope", False
        )
        if fd_config.speculative_config.model_type != "main":
            self.rope_3d = False
        self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", "32768"))
        self.zero_seq_enc_lens_for_decode = paddle.zeros(
            shape=[fd_config.scheduler_config.max_num_seqs, 1], dtype=paddle.int32
        )
    def get_attntion_meta(self):
        """get_attntion_meta"""
        return self.attention_metadata
    def get_kv_cache_shape(
        self,
        max_num_blocks: int,
        kv_cache_quant_type: str = None,
    ):
        """
        Calculate kv cache shape
        """
        key_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
        value_cache_shape = [max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim]
        if kv_cache_quant_type is not None and kv_cache_quant_type == "int4_zp":
            key_cache_shape = [
                max_num_blocks,
                self.kv_num_heads,
                self.block_size,
                self.head_dim // 2,
            ]
            value_cache_shape = [
                max_num_blocks,
                self.kv_num_heads,
                self.block_size,
                self.head_dim // 2,
            ]
        return key_cache_shape, value_cache_shape
    def init_attention_metadata(self, forward_meta: ForwardMeta):
        metadata = FlashMaskAttentionMetadata()
        metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
        metadata.rotary_embs = forward_meta.rotary_embs
        metadata.block_tables = forward_meta.block_tables
        get_block_shape_and_split_kv_block(
            forward_meta.seq_lens_encoder,
            forward_meta.seq_lens_decoder,
            forward_meta.seq_lens_this_time,
            forward_meta.decoder_batch_ids,
            forward_meta.decoder_tile_ids_per_batch,
            forward_meta.decoder_num_blocks_cpu,
            forward_meta.decoder_num_blocks_device,
            forward_meta.decoder_chunk_size_device,
            forward_meta.max_len_tensor_cpu,
            forward_meta.encoder_batch_ids,
            forward_meta.encoder_tile_ids_per_batch,
            forward_meta.encoder_num_blocks_x_cpu,
            forward_meta.kv_batch_ids,
            forward_meta.kv_tile_ids_per_batch,
            forward_meta.kv_num_blocks_x_cpu,
            self.encoder_block_shape_q,
            self.decoder_block_shape_q,
            self.group_size,
            self.block_size,
        )

        (
            metadata.cu_seqlens_k,
            metadata.pre_cache_batch_ids,
            metadata.pre_cache_tile_ids_per_batch,
            metadata.pre_cache_num_blocks_cpu,
            metadata.kv_token_num_cpu,
        ) = pre_cache_len_concat(
            forward_meta.seq_lens_decoder,
            forward_meta.seq_lens_this_time,
            forward_meta.max_len_tensor_cpu[2],
            self.block_size,
        )

        # pd_disaggregation
        metadata.kv_signal_data_list = [None] * self.num_layers
        if self.pd_disaggregation_mode == "per_chunk":
            if not self.keep_pd_step_flag and not forward_meta.is_dummy_or_profile_run:
                init_kv_signal_per_query(
                    forward_meta.seq_lens_encoder,
                    forward_meta.seq_lens_this_time,
                    forward_meta.seq_lens_decoder,
                    self.rank,
                    self.num_layers + self.num_layers_draft_model,
                )
        elif self.pd_disaggregation_mode == "per_query":
            metadata.kv_signal_metadata = open_shm_and_get_meta_signal(
                self.rank, int(self.device_id), self.keep_pd_step_flag
            )

        if metadata._dtype == "bfloat16":
            metadata._fuse_kernel_compute_dtype = "bf16"
        elif metadata._dtype == "float16":
            metadata._fuse_kernel_compute_dtype = "fp16"
        elif metadata._dtype == "float32":
            metadata._fuse_kernel_compute_dtype = "fp32"

        metadata.max_len_tensor_cpu = forward_meta.max_len_tensor_cpu
        metadata.max_len_tensor_cpu_decoder = paddle.clone(metadata.max_len_tensor_cpu)
        metadata.max_len_tensor_cpu_decoder[1] = 0

        self.attention_metadata = metadata
    def forward_mixed(
        self,
        q: paddle.Tensor,
        k: paddle.Tensor,
        v: paddle.Tensor,
        qkv: paddle.Tensor,
        compressed_kv: paddle.Tensor,
        k_pe: paddle.Tensor,
        layer: Attention,
        forward_meta: ForwardMeta,
    ):
        metadata = self.attention_metadata

        if self.pd_disaggregation_mode == "per_query":
            metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise(
                metadata.kv_signal_metadata,
                layer.layer_id + self.start_layer_index,
            )

        if metadata.max_len_tensor_cpu[1] > 0:
            res_encoder = paddle.zeros([qkv.shape[0], self.num_heads * self.head_dim], dtype=qkv.dtype)
            q, k, v, _ = gqa_rope_write_cache(
                qkv,
                forward_meta.caches[2 * layer.layer_id],
                forward_meta.caches[2 * layer.layer_id + 1],
                metadata.cu_seqlens_q,
                metadata.cu_seqlens_k,
                metadata.rotary_embs,
                forward_meta.seq_lens_this_time,
                forward_meta.seq_lens_encoder,
                forward_meta.seq_lens_decoder,
                forward_meta.batch_id_per_token,
                metadata.block_tables,
                forward_meta.kv_batch_ids,
                forward_meta.kv_tile_ids_per_batch,
                forward_meta.kv_num_blocks_x_cpu,
                metadata.pre_cache_batch_ids,
                metadata.pre_cache_tile_ids_per_batch,
                metadata.pre_cache_num_blocks_cpu,
                getattr(layer, "q_norm_weight", None),
                getattr(layer, "k_norm_weight", None),
                getattr(layer, "cache_k_scale", None),
                getattr(layer, "cache_v_scale", None),
                getattr(layer, "cache_k_out_scale", None),
                getattr(layer, "cache_v_out_scale", None),
                getattr(layer, "cache_k_zp", None),
                getattr(layer, "cache_v_zp", None),
                metadata.kv_signal_data_list[layer.layer_id],
                metadata.kv_token_num_cpu[0].item(),
                self.max_seq_len,
                getattr(layer, "rms_norm_eps", 1e-6),
                getattr(layer, "cache_quant_type_str", "none"),
                self.rope_3d,
            )

            flash_mask_attention(
                q,
                k,
                v,
                metadata.cu_seqlens_q,
                metadata.cu_seqlens_k,
                forward_meta.seq_lens_encoder,
                res_encoder,
                forward_meta.attn_mask_offsets,
                self.num_heads,
                self.kv_num_heads,
                self.head_dim,
                self.max_seq_len,
                q.shape[0],
                k.shape[0],
            )
            return res_encoder
        else:
            raise NotImplementedError("FlashMaskAttentionBackend does not support decode.")
@@ -15,6 +15,7 @@
"""

from .append_attention import append_attention, append_attention_with_output
from .flash_mask_attention import flash_mask_attention
from .get_block_shape_and_split_kv_block import get_block_shape_and_split_kv_block
from .gqa_rope_write_cache import gqa_rope_write_cache
from .init_kv_signal_per_query import init_kv_signal_per_query
@@ -31,4 +32,5 @@ __all__ = [
    "gqa_rope_write_cache",
    "pre_cache_len_concat",
    "init_kv_signal_per_query",
    "flash_mask_attention",
]
@@ -0,0 +1,60 @@
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

from typing import Optional

import paddle

from fastdeploy.platforms import current_platform


def flash_mask_attention(
    q: paddle.Tensor,
    k: paddle.Tensor,
    v: paddle.Tensor,
    cu_seqlens_q: paddle.Tensor,
    cu_seqlens_k: paddle.Tensor,
    seq_lens_encoder: paddle.Tensor,
    attn_out: paddle.Tensor,
    attn_mask_offsets: Optional[paddle.Tensor] = None,
    num_heads: int = 0,
    kv_num_heads: int = 0,
    head_dim: int = 128,
    max_seq_len: int = 0,
    q_token_num: int = 0,
    kv_token_num: int = 0,
):
    if current_platform.is_cuda():
        from fastdeploy.model_executor.ops.gpu import flash_mask_attention

        flash_mask_attention(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_k,
            seq_lens_encoder,
            attn_out,
            attn_mask_offsets,
            num_heads,
            kv_num_heads,
            head_dim,
            max_seq_len,
            q_token_num,
            kv_token_num,
        )
    else:
        raise NotImplementedError
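A hedged usage sketch for this wrapper (it assumes the compiled FastDeploy GPU ops are importable and that `attn_out` is preallocated; the tensor shapes follow my reading of the unit test below and may differ from the real calling convention):

```python
import paddle
from fastdeploy.model_executor.layers.attention.ops import flash_mask_attention

q_tokens, k_tokens = 888, 1912              # packed variable-length token counts
num_heads, kv_heads, head_dim = 8, 1, 128
q = paddle.randn([q_tokens, num_heads, head_dim], dtype="bfloat16")
k = paddle.randn([k_tokens, kv_heads, head_dim], dtype="bfloat16")
v = paddle.randn([k_tokens, kv_heads, head_dim], dtype="bfloat16")
cu_q = paddle.to_tensor([0, q_tokens], dtype="int32")   # single sequence
cu_k = paddle.to_tensor([0, k_tokens], dtype="int32")
seq_lens_enc = paddle.to_tensor([q_tokens], dtype="int32")
out = paddle.zeros([q_tokens, num_heads * head_dim], dtype="bfloat16")

flash_mask_attention(q, k, v, cu_q, cu_k, seq_lens_enc, out,
                     None,                  # attn_mask_offsets: omit for dense causal
                     num_heads, kv_heads, head_dim,
                     8192, q_tokens, k_tokens)
```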
@@ -39,6 +39,8 @@ def gqa_rope_write_cache(
    cache_batch_ids: paddle.Tensor,
    cache_tile_ids_per_batch: paddle.Tensor,
    cache_num_blocks: paddle.Tensor,
    q_norm_weight: Optional[paddle.Tensor] = None,
    k_norm_weight: Optional[paddle.Tensor] = None,
    cache_k_quant_scales: Optional[paddle.Tensor] = None,
    cache_v_quant_scales: Optional[paddle.Tensor] = None,
    cache_k_dequant_scales: Optional[paddle.Tensor] = None,
@@ -48,6 +50,7 @@ def gqa_rope_write_cache(
    kv_signal_data: Optional[paddle.Tensor] = None,
    kv_token_num: int = 1,
    max_seq_len: int = 0,
    rms_norm_eps: float = 1e-6,
    cache_quant_type: str = "none",
    rope_3d: bool = False,
):
@@ -72,6 +75,8 @@ def gqa_rope_write_cache(
        cache_batch_ids,
        cache_tile_ids_per_batch,
        cache_num_blocks,
        q_norm_weight,
        k_norm_weight,
        cache_k_quant_scales,
        cache_v_quant_scales,
        cache_k_dequant_scales,
@@ -81,6 +86,7 @@ def gqa_rope_write_cache(
        kv_signal_data,
        kv_token_num,
        max_seq_len,
        rms_norm_eps,
        cache_quant_type,
        rope_3d,
    )
@@ -28,6 +28,7 @@ class _Backend(enum.Enum):
    BLOCK_ATTN = enum.auto()
    PLAS_ATTN = enum.auto()
    HPU_ATTN = enum.auto()
    FLASH_MASK_ATTN = enum.auto()


class Platform:

@@ -67,6 +67,9 @@ class CUDAPlatform(Platform):
        elif selected_backend == _Backend.PLAS_ATTN:
            logger.info("Using PLAS ATTN backend.")
            return "fastdeploy.model_executor.layers.attention.PlasAttentionBackend"
        elif selected_backend == _Backend.FLASH_MASK_ATTN:
            logger.info("Using FLASH MASK ATTN backend.")
            return "fastdeploy.model_executor.layers.attention.FlashMaskAttentionBackend"
        else:
            raise ValueError(
                "Invalid attention backend you specified.\n"
@@ -19,7 +19,7 @@ import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.gpu import flash_attention_mask
from fastdeploy.model_executor.ops.gpu import flash_mask_attention


class TestFlashMaskAttention(unittest.TestCase):
@@ -27,7 +27,7 @@ class TestFlashMaskAttention(unittest.TestCase):
        self.bsz = 1
        self.num_head = 8
        self.num_kv_head = 1
        self.q_seq_len = 1024
        self.q_seq_len = 888
        self.k_seq_len = 1024
        self.head_dim = 128
        np.random.seed(self.q_seq_len)
@@ -71,7 +71,7 @@ class TestFlashMaskAttention(unittest.TestCase):
        v_input_pad[0 : v_input.shape[0]] = v_input
        mask = paddle.to_tensor(mask).astype("int32")

        flash_attention_mask(
        flash_mask_attention(
            q_input,
            k_input,
            v_input_pad,
@@ -88,7 +88,7 @@ class TestFlashMaskAttention(unittest.TestCase):
            int(k_input.shape[0]),
        )

    def test_flash_attention_mask(self):
    def test_flash_mask_attention(self):
        q_input = np.random.normal(0, 0.5, size=(self.bsz, self.q_seq_len, self.num_head, self.head_dim))
        k_input = np.random.normal(
            0, 0.5, size=(self.bsz, self.q_seq_len + self.k_seq_len, self.num_kv_head, self.head_dim)
@@ -104,8 +104,8 @@ class TestFlashMaskAttention(unittest.TestCase):
        mask = np.array([i + 1 for i in range(0, self.q_seq_len)]) + self.k_seq_len
        mask[text_len : text_len + image_len] = text_len + image_len + self.k_seq_len

        naive_attn_out = self.naive_attn(q_input, k_input, v_input, mask)
        paddle_attn_out = paddle.zeros(naive_attn_out.shape, dtype="bfloat16")
        naive_attn_out = self.naive_attn(q_input, k_input, v_input, mask).transpose([0, 2, 1, 3])
        paddle_attn_out = paddle.zeros(q_input.shape, dtype="bfloat16")
        self.paddle_flash_attn_mask(q_input, k_input, v_input, paddle_attn_out, mask)

        max_diff = float((paddle_attn_out.reshape([-1]) - paddle.to_tensor(naive_attn_out).reshape([-1])).max())
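The test compares against `self.naive_attn`, whose body is outside this diff. For completeness, here is a plausible minimal reference with the same mask semantics (a sketch under my assumptions about tensor layouts; the repo's version may differ, e.g. in the output transpose):

```python
import numpy as np

def naive_attn(q, k, v, mask_end):
    """q: [bsz, q_len, heads, dim]; k, v: [bsz, k_len, kv_heads, dim];
    mask_end[i] is the exclusive end of keys visible to query i."""
    bsz, q_len, heads, dim = q.shape
    group = heads // k.shape[2]
    out = np.zeros_like(q)
    scale = 1.0 / np.sqrt(dim)
    for b in range(bsz):
        for h in range(heads):
            kh = h // group                              # GQA head sharing
            s = q[b, :, h] @ k[b, :, kh].T * scale       # [q_len, k_len]
            for i in range(q_len):
                s[i, mask_end[i]:] = -np.inf             # hide masked keys
            w = np.exp(s - s.max(axis=-1, keepdims=True))
            w /= w.sum(axis=-1, keepdims=True)
            out[b, :, h] = w @ v[b, :, kh]
    return out
```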