FA3 support qwen3 (#5441)

周周周
2025-12-09 16:16:16 +08:00
committed by GitHub
parent 83ea9646f9
commit 31410415db
6 changed files with 242 additions and 44 deletions


@@ -17,6 +17,7 @@
#include "paddle/extension.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/core/memory/memcpy.h"
#include "qwen3_rope.h"
#include "remote_cache_kv_ipc.h"
template <typename T, int VecSize = 1>
@@ -173,7 +174,7 @@ void gqa_rotary_qk_split_variable(
T *k,
T *v,
const T *qkv_input,
const float *rotary_emb, // [2, 1, 1, seq_len, head_dim / 2]
const float *rotary_emb, // [2, 1, seq_len, 1, head_dim / 2]
const float *q_norm_weight,
const float *k_norm_weight,
const int *batch_id_per_token,
@@ -1136,6 +1137,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const int kv_token_num,
const int max_seq_len,
const float rms_norm_eps,
const bool use_neox_rotary_style,
const std::string &cache_quant_type,
const bool rope_3d) {
typedef PDTraits<paddle::DataType::BFLOAT16> traits_;
@@ -1157,6 +1159,24 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
qkv_dims[qkv_dims.size() - 1] / head_dim - 2 * kv_num_heads;
const float softmax_scale = 1.f / sqrt(head_dim);
PADDLE_ENFORCE_EQ(batch_id_per_token.dims().size(), 1);
PADDLE_ENFORCE_EQ(batch_id_per_token.dims()[0], token_num);
if (!rope_3d) {
PADDLE_ENFORCE_EQ(rotary_embs.dims().size(), 5);
PADDLE_ENFORCE_EQ(rotary_embs.dims()[0], 2);
PADDLE_ENFORCE_EQ(rotary_embs.dims()[1], 1);
PADDLE_ENFORCE_EQ(rotary_embs.dims()[2], max_seq_len);
PADDLE_ENFORCE_EQ(rotary_embs.dims()[3], 1);
if (use_neox_rotary_style) {
// Note(ZKK): for Qwen3-like models, the [0, head_dim/2) and
// [head_dim/2, head_dim) halves of the table hold identical data!
PADDLE_ENFORCE_EQ(rotary_embs.dims()[4], head_dim);
} else {
PADDLE_ENFORCE_EQ(rotary_embs.dims()[4], head_dim / 2);
}
}
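// In both layouts dim 0 stacks the cos table over the sin table; the
// NeoX path stores the full head_dim per position (halves duplicated),
// while the interleaved path stores head_dim / 2.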
AppendAttnMetaData meta_data;
meta_data.token_nums = token_num;
meta_data.kv_num_heads = kv_num_heads;
@@ -1175,30 +1195,49 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
paddle::Tensor v = GetEmptyTensor(
{kv_token_num, kv_num_heads, head_dim}, qkv.dtype(), qkv.place());
// rope
gqa_rotary_qk_split_variable<data_t>(
qkv_out.data<data_t>(),
q.data<data_t>(),
k.data<data_t>(),
v.data<data_t>(),
qkv.data<data_t>(),
rotary_embs.data<float>(),
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
batch_id_per_token.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
token_num,
num_heads,
kv_num_heads,
max_seq_len,
rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
head_dim,
rope_3d,
rms_norm_eps,
stream);
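// Dispatch: the NeoX-style (Qwen3-like) kernel applies rope with the
// duplicated cos/sin halves and takes no q/k norm weights and no rope_3d
// handling; the interleaved path below keeps the original behavior.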
if (use_neox_rotary_style) {
gqa_rotary_qk_split_variable_qwen3<data_t>(qkv_out.data<data_t>(),
q.data<data_t>(),
k.data<data_t>(),
v.data<data_t>(),
qkv.data<data_t>(),
rotary_embs.data<float>(),
batch_id_per_token.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
token_num,
num_heads,
kv_num_heads,
max_seq_len,
head_dim,
stream);
} else {
gqa_rotary_qk_split_variable<data_t>(
qkv_out.data<data_t>(),
q.data<data_t>(),
k.data<data_t>(),
v.data<data_t>(),
qkv.data<data_t>(),
rotary_embs.data<float>(),
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
batch_id_per_token.data<int>(),
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
cu_seqlens_q.data<int>(),
cu_seqlens_k.data<int>(),
token_num,
num_heads,
kv_num_heads,
max_seq_len,
rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
head_dim,
rope_3d,
rms_norm_eps,
stream);
}
if (token_num < kv_token_num) {
AppendCacheKV<data_t, 128, 64>(key_cache,
@@ -1347,6 +1386,7 @@ PD_BUILD_STATIC_OP(gqa_rope_write_cache)
.Attrs({"kv_token_num: int",
"max_seq_len: int",
"rms_norm_eps: float",
"use_neox_rotary_style: bool",
"cache_quant_type: std::string",
"rope_3d: bool"})
.SetKernelFn(PD_KERNEL(GQARopeWriteCacheKernel));
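For context, here is a minimal host-side sketch of a rotary table in the layout the shape checks above enforce for the NeoX path ([2, 1, max_seq_len, 1, head_dim], cos table then sin table, with the two head_dim / 2 halves duplicated). The fill routine, the function name, and the base frequency of 10000.f are illustrative assumptions, not part of this commit.

#include <cmath>
#include <vector>

// Illustrative only: build a [2, 1, max_seq_len, 1, head_dim] cos/sin table
// in the NeoX layout checked above (both halves hold the same values).
std::vector<float> make_neox_rope_table(int max_seq_len, int head_dim,
                                        float base = 10000.f) {
  const int half = head_dim / 2;
  std::vector<float> table(2ull * max_seq_len * head_dim);
  float *cos_part = table.data();                                  // dim 0 == 0
  float *sin_part = table.data() + 1ull * max_seq_len * head_dim;  // dim 0 == 1
  for (int pos = 0; pos < max_seq_len; ++pos) {
    for (int i = 0; i < half; ++i) {
      const float freq = 1.f / std::pow(base, 2.f * i / head_dim);
      const float angle = pos * freq;
      // Duplicate each value into both halves, matching the head_dim check.
      cos_part[pos * head_dim + i] = std::cos(angle);
      cos_part[pos * head_dim + half + i] = std::cos(angle);
      sin_part[pos * head_dim + i] = std::sin(angle);
      sin_part[pos * head_dim + half + i] = std::sin(angle);
    }
  }
  return table;
}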


@@ -0,0 +1,167 @@
#include "encoder_write_cache_with_rope_impl.cuh"
#include "helper.h"
#include "paddle/extension.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/core/memory/memcpy.h"
#include "remote_cache_kv_ipc.h"
template <typename T, int VecSize = 1>
__global__ void GQAVariableLengthRotarySplitKernel_Qwen3(
const T *qkv,
const float *cos_emb,
const float *sin_emb,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int *cu_seqlens_k,
T *qkv_out,
T *q,
T *k,
T *v,
const int64_t elem_cnt,
const int q_num_head,
const int kv_num_head,
const int max_model_len,
const int head_dim) {
using LoadT = AlignedVector<T, VecSize>;
using LoadEmbT = AlignedVector<float, VecSize>;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
const int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const int offset = (q_num_head + kv_num_head * 2) * (head_dim / 2);
const int64_t loop_times = elem_cnt / 2;
for (int64_t linear_index = global_thread_idx * VecSize;
linear_index < loop_times;
linear_index += gridDim.x * blockDim.x * VecSize) {
const int token_idx = linear_index / offset;
const int ori_bi = batch_id_per_token[token_idx];  // which batch this token belongs to
int cache_kv_len = seq_lens_decoder[ori_bi];
// Strictly speaking this case needs no handling, but an FA3 bug makes it necessary.
if (seq_lens_encoder[ori_bi] == 0) cache_kv_len = 0;
const int bias = linear_index % offset;
const int hi = bias / (head_dim / 2);
const int h_bias = bias % (head_dim / 2);
// This thread handles the h_bias slice of head hi for token token_idx.
const int ori_seq_id =
(token_idx - cu_seqlens_q[ori_bi]) +
cache_kv_len;  // position within the current sequence (valid when sequences are packed into one batch)
const int half_headdim = head_dim / 2;
const int64_t emb_idx = ori_seq_id * head_dim + h_bias;  // index into the rotary table
const int64_t read_idx =
token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim +
h_bias;
LoadT src_vec0;
LoadT src_vec1;
Load<T, VecSize>(&qkv[read_idx], &src_vec0);
Load<T, VecSize>(&qkv[read_idx + half_headdim], &src_vec1);
const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;
int64_t base_split_idx;
T *out_p = nullptr;
if (hi < q_num_head) {
base_split_idx =
token_idx * q_num_head * head_dim + hi * head_dim + h_bias;
out_p = q;
} else if (hi < q_num_head + kv_num_head) {
base_split_idx = kv_write_idx * kv_num_head * head_dim +
(hi - q_num_head) * head_dim + h_bias;
out_p = k;
} else {
out_p = v;
base_split_idx = kv_write_idx * kv_num_head * head_dim +
(hi - q_num_head - kv_num_head) * head_dim + h_bias;
}
// TODO: verify that this indexing is correct
int64_t new_emb_idx = emb_idx;
if (hi < q_num_head + kv_num_head) {
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
float input_left = static_cast<float>(src_vec0[i]);
float input_right = static_cast<float>(src_vec1[i]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
src_vec0[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
src_vec1[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
}
}
Store<T, VecSize>(src_vec0, &qkv_out[read_idx]);
Store<T, VecSize>(src_vec0, &out_p[base_split_idx]);
Store<T, VecSize>(src_vec1, &qkv_out[read_idx + half_headdim]);
Store<T, VecSize>(src_vec1, &out_p[base_split_idx + half_headdim]);
}
}
template <typename T>
void gqa_rotary_qk_split_variable_qwen3(T *qkv_out,
T *q,
T *k,
T *v,
const T *qkv_input,
const float *rotary_emb,
const int *batch_id_per_token,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int *cu_seqlens_q,
const int *cu_seqlens_k,
const int token_num,
const int num_heads,
const int kv_num_heads,
const int max_model_len,
const int head_dim,
const cudaStream_t &stream) {
assert(head_dim == 128 && "head_dim must be 128");
const int64_t elem_nums =
    static_cast<int64_t>(token_num) * (num_heads + 2 * kv_num_heads) * head_dim;
constexpr int PackSize = 8;
const int pack_num = elem_nums / PackSize;
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
dim3 block_size(128);
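// rotary_emb holds the cos table followed by the sin table
// ([2, 1, max_model_len, 1, head_dim]), so the sin half starts
// max_model_len * head_dim floats in.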
const float *cos_emb = rotary_emb;
const float *sin_emb = rotary_emb + max_model_len * head_dim;
launchWithPdlWhenEnabled(
GQAVariableLengthRotarySplitKernel_Qwen3<T, PackSize>,
grid_size,
block_size,
0,
stream,
qkv_input,
cos_emb,
sin_emb,
batch_id_per_token,
cu_seqlens_q,
seq_lens_encoder,
seq_lens_decoder,
cu_seqlens_k,
qkv_out,
q,
k,
v,
elem_nums,
num_heads,
kv_num_heads,
max_model_len,
head_dim);
}
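As a sanity reference, here is a scalar sketch of the rotation the kernel applies per element pair, mirroring the arithmetic in the unrolled loop above (left half at h_bias, right half at h_bias + head_dim / 2). The standalone helper and its name are illustrative, not part of this commit.

#include <utility>

// Illustrative only: rotate one (left, right) pair exactly as the kernel's
// unrolled loop does.
inline std::pair<float, float> rotate_pair(float input_left,
                                           float input_right,
                                           float cos_tmp,
                                           float sin_tmp) {
  return {input_left * cos_tmp - input_right * sin_tmp,
          input_right * cos_tmp + input_left * sin_tmp};
}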


@@ -190,6 +190,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const int kv_token_num,
const int max_seq_len,
const float rms_norm_eps,
const bool use_neox_rotary_style,
const std::string& cache_quant_type,
const bool rope_3d);