support fa3 rope3d (#3622)

This commit is contained in:
xiaoxiaohehe001
2025-08-27 11:31:29 +08:00
committed by GitHub
parent 85afa72763
commit ad319a87cc
4 changed files with 17 additions and 7 deletions

View File

@@ -37,7 +37,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
const int q_num_head,
const int kv_num_head,
const int seq_len,
const int last_dim) {
const int last_dim,
const bool rope_3d) {
using LoadT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
@@ -62,6 +63,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;
const int64_t emb_idx = ori_seq_id * half_lastdim + h_bias / 2;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
const int64_t base_idx =
token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
h_bias;
@@ -80,8 +82,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
Load<T, VecSize>(&qkv[base_idx], &src_vec);
// do rope
if (hi < q_num_head + kv_num_head) {
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
const float input_left = static_cast<float>(src_vec[2 * i]);
@@ -118,6 +120,7 @@ void gqa_rotary_qk_split_variable(
const int seq_len,
const int input_output_len,
const int dim_head,
const bool rope_3d,
const cudaStream_t &stream) {
int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * dim_head;
constexpr int PackSize = 16 / sizeof(T);
@@ -146,7 +149,8 @@ void gqa_rotary_qk_split_variable(
num_heads,
kv_num_heads,
seq_len,
dim_head);
dim_head,
rope_3d);
}
template <typename T,
@@ -890,7 +894,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::optional<paddle::Tensor>& kv_signal_data,
const int kv_token_num,
const int max_seq_len,
const std::string& cache_quant_type) {
const std::string& cache_quant_type,
const bool rope_3d) {
typedef PDTraits<paddle::DataType::BFLOAT16> traits_;
typedef typename traits_::DataType DataType_;
typedef typename traits_::data_t data_t;
@@ -953,8 +958,9 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
num_heads,
kv_num_heads,
max_seq_len,
rotary_embs.dims()[2],
rope_3d ? rotary_embs.dims()[3] : rotary_embs.dims()[2],
head_dim,
rope_3d,
stream);
if (token_num < kv_token_num) {

View File

@@ -111,7 +111,8 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const paddle::optional<paddle::Tensor> &cache_v_zp,
const paddle::optional<paddle::Tensor> &kv_signal_data,
const int kv_token_num, const int max_seq_len,
const std::string &cache_quant_type);
const std::string &cache_quant_type,
const bool rope_3d);
std::vector<paddle::Tensor>
PreCacheLenConcat(const paddle::Tensor &seq_lens_decoder,