mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
support glm fa3 (#5586)
This commit is contained in:
@@ -232,6 +232,179 @@ void gqa_rotary_qk_split_variable(
|
||||
rms_norm_eps);
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void GQAVariableLengthNeoxPartialRotarySplitKernel(
|
||||
const T *qkv,
|
||||
const float *cos_emb,
|
||||
const float *sin_emb,
|
||||
const int *batch_id_per_token,
|
||||
const int *cu_seqlens_q,
|
||||
const int *seq_lens_encoder,
|
||||
const int *seq_lens_decoder,
|
||||
const int *cu_seqlens_k,
|
||||
T *qkv_out,
|
||||
T *q,
|
||||
T *k,
|
||||
T *v,
|
||||
const int64_t elem_cnt,
|
||||
const int q_num_head,
|
||||
const int kv_num_head,
|
||||
const int max_model_len,
|
||||
const int head_dim,
|
||||
const int rotary_dim) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadEmbT = AlignedVector<float, VecSize>;
|
||||
LoadT src_vec;
|
||||
LoadT src_vec_right;
|
||||
LoadEmbT cos_emb_vec;
|
||||
LoadEmbT sin_emb_vec;
|
||||
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
|
||||
int64_t all_warp_num = gridDim.x * blockDim.y;
|
||||
const int half_rotary_dim = rotary_dim / 2;
|
||||
const int half_headdim = head_dim / 2;
|
||||
const int offset =
|
||||
(q_num_head + kv_num_head * 2) * head_dim; // for all q,k,v
|
||||
const int all_head_num = elem_cnt / head_dim;
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaGridDependencySynchronize();
|
||||
#endif
|
||||
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num;
|
||||
gloabl_hi += all_warp_num) {
|
||||
int64_t linear_index =
|
||||
gloabl_hi * head_dim + threadIdx.x * VecSize; // 全局index
|
||||
const int token_idx =
|
||||
linear_index / offset; // token id(第几个token,不分qkv)
|
||||
const int ori_bi = batch_id_per_token[token_idx]; // 第几个batch
|
||||
|
||||
int cache_kv_len = seq_lens_decoder[ori_bi];
|
||||
// 这里其实是不需要处理的,但是由于FA3的bug,所以必须!
|
||||
if (seq_lens_encoder[ori_bi] == 0) cache_kv_len = 0;
|
||||
|
||||
const int bias = linear_index % offset;
|
||||
const int hi = bias / head_dim;
|
||||
const int h_bias = bias % head_dim;
|
||||
|
||||
const int ori_seq_id =
|
||||
(token_idx - cu_seqlens_q[ori_bi]) +
|
||||
cache_kv_len; // 在当前seq中的id(拼接了seq到一个batch的情况下有效)
|
||||
const int64_t base_idx =
|
||||
token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim +
|
||||
h_bias;
|
||||
Load<T, VecSize>(&qkv[base_idx], &src_vec);
|
||||
const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;
|
||||
int64_t base_split_idx;
|
||||
T *out_p = nullptr;
|
||||
if (hi < q_num_head) {
|
||||
base_split_idx =
|
||||
token_idx * q_num_head * head_dim + hi * head_dim + h_bias;
|
||||
out_p = q;
|
||||
} else if (hi < q_num_head + kv_num_head) {
|
||||
base_split_idx = kv_write_idx * kv_num_head * head_dim +
|
||||
(hi - q_num_head) * head_dim + h_bias;
|
||||
out_p = k;
|
||||
} else {
|
||||
out_p = v;
|
||||
base_split_idx = kv_write_idx * kv_num_head * head_dim +
|
||||
(hi - q_num_head - kv_num_head) * head_dim + h_bias;
|
||||
}
|
||||
|
||||
if (hi < q_num_head + kv_num_head) {
|
||||
if (h_bias < rotary_dim) {
|
||||
int64_t emb_idx = ori_seq_id * half_rotary_dim;
|
||||
if (h_bias < half_rotary_dim) {
|
||||
Load<T, VecSize>(&qkv[base_idx + half_rotary_dim], &src_vec_right);
|
||||
emb_idx += h_bias;
|
||||
} else {
|
||||
Load<T, VecSize>(&qkv[base_idx - half_rotary_dim], &src_vec_right);
|
||||
emb_idx += h_bias - half_rotary_dim;
|
||||
}
|
||||
Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
const float input_left = static_cast<float>(src_vec[i]);
|
||||
const float input_right = static_cast<float>(src_vec_right[i]);
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
if (h_bias < half_rotary_dim) {
|
||||
src_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
} else {
|
||||
src_vec[i] =
|
||||
static_cast<T>(input_left * cos_tmp + input_right * sin_tmp);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Store<T, VecSize>(src_vec, &qkv_out[base_idx]);
|
||||
Store<T, VecSize>(src_vec, &out_p[base_split_idx]);
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void gqa_neox_partial_rotary_qk_split_variable(
|
||||
T *qkv_out, // [token_num, 3, num_head, head_dim]
|
||||
T *q,
|
||||
T *k,
|
||||
T *v,
|
||||
const T *qkv_input,
|
||||
const float *rotary_emb, // [2, 1, seq_len, 1, head_dim / 4]
|
||||
const int *batch_id_per_token,
|
||||
const int *seq_lens_encoder,
|
||||
const int *seq_lens_decoder,
|
||||
const int *cu_seqlens_q,
|
||||
const int *cu_seqlens_k,
|
||||
const int token_num,
|
||||
const int num_heads,
|
||||
const int kv_num_heads,
|
||||
const int max_model_len,
|
||||
const int head_dim,
|
||||
const int rotary_dim,
|
||||
const cudaStream_t &stream) {
|
||||
assert(head_dim == 128 && "head_dim must be 128");
|
||||
int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * head_dim;
|
||||
|
||||
constexpr int HEAD_DIM = 128;
|
||||
constexpr int PackSize = HEAD_DIM / kWarpSize;
|
||||
assert(rotary_dim / 2 % PackSize == 0);
|
||||
const int pack_num = elem_nums / PackSize;
|
||||
const int blocksize = 128;
|
||||
int grid_size = 1;
|
||||
GetNumBlocks<128>(pack_num, &grid_size);
|
||||
dim3 block_size(kWarpSize, blocksize / kWarpSize);
|
||||
|
||||
const float *cos_emb = rotary_emb;
|
||||
const float *sin_emb = rotary_emb + max_model_len * rotary_dim / 2;
|
||||
launchWithPdlWhenEnabled(
|
||||
GQAVariableLengthNeoxPartialRotarySplitKernel<T, PackSize>,
|
||||
grid_size,
|
||||
block_size,
|
||||
0,
|
||||
stream,
|
||||
qkv_input,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
batch_id_per_token,
|
||||
cu_seqlens_q,
|
||||
seq_lens_encoder,
|
||||
seq_lens_decoder,
|
||||
cu_seqlens_k,
|
||||
qkv_out,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
elem_nums,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
max_model_len,
|
||||
head_dim,
|
||||
rotary_dim);
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename CacheT,
|
||||
uint32_t HEAD_DIM,
|
||||
@@ -1158,6 +1331,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
const int num_heads =
|
||||
qkv_dims[qkv_dims.size() - 1] / head_dim - 2 * kv_num_heads;
|
||||
const float softmax_scale = 1.f / sqrt(head_dim);
|
||||
int rotary_dim = head_dim;
|
||||
|
||||
PADDLE_ENFORCE_EQ(batch_id_per_token.dims().size(), 1);
|
||||
PADDLE_ENFORCE_EQ(batch_id_per_token.dims()[0], token_num);
|
||||
@@ -1171,7 +1345,13 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
if (use_neox_rotary_style) {
|
||||
// Note(ZKK) Qwen3 like model
|
||||
// the [0,head_dim/2), [head_dim/2,head_dim) data are totally same!
|
||||
PADDLE_ENFORCE_EQ(rotary_embs.dims()[4], head_dim);
|
||||
if (rotary_embs.dims()[4] == head_dim) {
|
||||
rotary_dim = head_dim;
|
||||
} else {
|
||||
// for glm partial rotary style
|
||||
PADDLE_ENFORCE_EQ(rotary_embs.dims()[4], head_dim / 4);
|
||||
rotary_dim = head_dim / 2;
|
||||
}
|
||||
} else {
|
||||
PADDLE_ENFORCE_EQ(rotary_embs.dims()[4], head_dim / 2);
|
||||
}
|
||||
@@ -1196,23 +1376,45 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
|
||||
{kv_token_num, kv_num_heads, head_dim}, qkv.dtype(), qkv.place());
|
||||
|
||||
if (use_neox_rotary_style) {
|
||||
gqa_rotary_qk_split_variable_qwen3<data_t>(qkv_out.data<data_t>(),
|
||||
q.data<data_t>(),
|
||||
k.data<data_t>(),
|
||||
v.data<data_t>(),
|
||||
qkv.data<data_t>(),
|
||||
rotary_embs.data<float>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
token_num,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
max_seq_len,
|
||||
head_dim,
|
||||
stream);
|
||||
if (rotary_dim == head_dim) {
|
||||
gqa_rotary_qk_split_variable_qwen3<data_t>(qkv_out.data<data_t>(),
|
||||
q.data<data_t>(),
|
||||
k.data<data_t>(),
|
||||
v.data<data_t>(),
|
||||
qkv.data<data_t>(),
|
||||
rotary_embs.data<float>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
token_num,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
max_seq_len,
|
||||
head_dim,
|
||||
stream);
|
||||
} else {
|
||||
gqa_neox_partial_rotary_qk_split_variable<data_t>(
|
||||
qkv_out.data<data_t>(),
|
||||
q.data<data_t>(),
|
||||
k.data<data_t>(),
|
||||
v.data<data_t>(),
|
||||
qkv.data<data_t>(),
|
||||
rotary_embs.data<float>(),
|
||||
batch_id_per_token.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
seq_lens_decoder.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
cu_seqlens_k.data<int>(),
|
||||
token_num,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
max_seq_len,
|
||||
head_dim,
|
||||
rotary_dim,
|
||||
stream);
|
||||
}
|
||||
} else {
|
||||
gqa_rotary_qk_split_variable<data_t>(
|
||||
qkv_out.data<data_t>(),
|
||||
|
||||
Reference in New Issue
Block a user