diff --git a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu
index 804bbac4e..76f001890 100644
--- a/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu
+++ b/custom_ops/gpu_ops/append_attn/gqa_rope_write_cache.cu
@@ -232,6 +232,179 @@ void gqa_rotary_qk_split_variable(
       rms_norm_eps);
 }
 
+template <typename T, int VecSize>
+__global__ void GQAVariableLengthNeoxPartialRotarySplitKernel(
+    const T *qkv,
+    const float *cos_emb,
+    const float *sin_emb,
+    const int *batch_id_per_token,
+    const int *cu_seqlens_q,
+    const int *seq_lens_encoder,
+    const int *seq_lens_decoder,
+    const int *cu_seqlens_k,
+    T *qkv_out,
+    T *q,
+    T *k,
+    T *v,
+    const int64_t elem_cnt,
+    const int q_num_head,
+    const int kv_num_head,
+    const int max_model_len,
+    const int head_dim,
+    const int rotary_dim) {
+  using LoadT = AlignedVector<T, VecSize>;
+  using LoadEmbT = AlignedVector<float, VecSize>;
+  LoadT src_vec;
+  LoadT src_vec_right;
+  LoadEmbT cos_emb_vec;
+  LoadEmbT sin_emb_vec;
+  int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
+  int64_t all_warp_num = gridDim.x * blockDim.y;
+  const int half_rotary_dim = rotary_dim / 2;
+  const int half_headdim = head_dim / 2;
+  const int offset =
+      (q_num_head + kv_num_head * 2) * head_dim;  // for all q,k,v
+  const int all_head_num = elem_cnt / head_dim;
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  cudaGridDependencySynchronize();
+#endif
+  for (int global_hi = global_warp_idx; global_hi < all_head_num;
+       global_hi += all_warp_num) {
+    int64_t linear_index =
+        global_hi * head_dim + threadIdx.x * VecSize;  // global index
+    const int token_idx =
+        linear_index / offset;  // which token (q/k/v still packed together)
+    const int ori_bi = batch_id_per_token[token_idx];  // which batch
+
+    int cache_kv_len = seq_lens_decoder[ori_bi];
+    // In principle this needs no special handling, but an FA3 bug makes it
+    // mandatory!
+    if (seq_lens_encoder[ori_bi] == 0) cache_kv_len = 0;
+
+    const int bias = linear_index % offset;
+    const int hi = bias / head_dim;
+    const int h_bias = bias % head_dim;
+
+    const int ori_seq_id =
+        (token_idx - cu_seqlens_q[ori_bi]) +
+        cache_kv_len;  // index within the current sequence (valid when
+                       // sequences are packed into one batch)
+    const int64_t base_idx =
+        token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim +
+        h_bias;
+    Load<T, VecSize>(&qkv[base_idx], &src_vec);
+    const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;
+    int64_t base_split_idx;
+    T *out_p = nullptr;
+    if (hi < q_num_head) {
+      base_split_idx =
+          token_idx * q_num_head * head_dim + hi * head_dim + h_bias;
+      out_p = q;
+    } else if (hi < q_num_head + kv_num_head) {
+      base_split_idx = kv_write_idx * kv_num_head * head_dim +
+                       (hi - q_num_head) * head_dim + h_bias;
+      out_p = k;
+    } else {
+      out_p = v;
+      base_split_idx = kv_write_idx * kv_num_head * head_dim +
+                       (hi - q_num_head - kv_num_head) * head_dim + h_bias;
+    }
+
+    if (hi < q_num_head + kv_num_head) {
+      if (h_bias < rotary_dim) {
+        int64_t emb_idx = ori_seq_id * half_rotary_dim;
+        if (h_bias < half_rotary_dim) {
+          Load<T, VecSize>(&qkv[base_idx + half_rotary_dim], &src_vec_right);
+          emb_idx += h_bias;
+        } else {
+          Load<T, VecSize>(&qkv[base_idx - half_rotary_dim], &src_vec_right);
+          emb_idx += h_bias - half_rotary_dim;
+        }
+        Load<float, VecSize>(&cos_emb[emb_idx], &cos_emb_vec);
+        Load<float, VecSize>(&sin_emb[emb_idx], &sin_emb_vec);
+#pragma unroll
+        for (int i = 0; i < VecSize; i++) {
+          const float input_left = static_cast<float>(src_vec[i]);
+          const float input_right = static_cast<float>(src_vec_right[i]);
+          const float cos_tmp = cos_emb_vec[i];
+          const float sin_tmp = sin_emb_vec[i];
+          if (h_bias < half_rotary_dim) {
+            src_vec[i] =
+                static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
+          } else {
+            src_vec[i] =
+                static_cast<T>(input_left * cos_tmp + input_right * sin_tmp);
+          }
+        }
+      }
+    }
+
+    Store<T, VecSize>(src_vec, &qkv_out[base_idx]);
+    Store<T, VecSize>(src_vec, &out_p[base_split_idx]);
+  }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  cudaTriggerProgrammaticLaunchCompletion();
+#endif
+}
+
+template <typename T>
+void gqa_neox_partial_rotary_qk_split_variable(
+    T *qkv_out,  // [token_num, 3, num_head, head_dim]
+    T *q,
+    T *k,
+    T *v,
+    const T *qkv_input,
+    const float *rotary_emb,  // [2, 1, seq_len, 1, head_dim / 4]
+    const int *batch_id_per_token,
+    const int *seq_lens_encoder,
+    const int *seq_lens_decoder,
+    const int *cu_seqlens_q,
+    const int *cu_seqlens_k,
+    const int token_num,
+    const int num_heads,
+    const int kv_num_heads,
+    const int max_model_len,
+    const int head_dim,
+    const int rotary_dim,
+    const cudaStream_t &stream) {
+  assert(head_dim == 128 && "head_dim must be 128");
+  int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * head_dim;
+
+  constexpr int HEAD_DIM = 128;
+  constexpr int PackSize = HEAD_DIM / kWarpSize;
+  assert(rotary_dim / 2 % PackSize == 0);
+  const int pack_num = elem_nums / PackSize;
+  const int blocksize = 128;
+  int grid_size = 1;
+  GetNumBlocks<128>(pack_num, &grid_size);
+  dim3 block_size(kWarpSize, blocksize / kWarpSize);
+
+  const float *cos_emb = rotary_emb;
+  const float *sin_emb = rotary_emb + max_model_len * rotary_dim / 2;
+  launchWithPdlWhenEnabled(
+      GQAVariableLengthNeoxPartialRotarySplitKernel<T, PackSize>,
+      grid_size,
+      block_size,
+      0,
+      stream,
+      qkv_input,
+      cos_emb,
+      sin_emb,
+      batch_id_per_token,
+      cu_seqlens_q,
+      seq_lens_encoder,
+      seq_lens_decoder,
+      cu_seqlens_k,
+      qkv_out,
+      q,
+      k,
+      v,
+      elem_nums,
+      num_heads,
+      kv_num_heads,
+      max_model_len,
+      head_dim,
+      rotary_dim);
+}
+
+template …
@@ … @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
   const int num_heads = qkv_dims[qkv_dims.size() - 1] / head_dim - 2 * kv_num_heads;
   const float softmax_scale = 1.f / sqrt(head_dim);
+  int rotary_dim = head_dim;
   PADDLE_ENFORCE_EQ(batch_id_per_token.dims().size(), 1);
   PADDLE_ENFORCE_EQ(batch_id_per_token.dims()[0], token_num);
@@ -1171,7 +1345,13 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
   if (use_neox_rotary_style) {
     // Note(ZKK): for a Qwen3-like model, the [0, head_dim/2) and
    // [head_dim/2, head_dim) halves of the table are identical!
-    PADDLE_ENFORCE_EQ(rotary_embs.dims()[4], head_dim);
+    if (rotary_embs.dims()[4] == head_dim) {
+      rotary_dim = head_dim;
+    } else {
+      // for the GLM partial-rotary style
+      PADDLE_ENFORCE_EQ(rotary_embs.dims()[4], head_dim / 4);
+      rotary_dim = head_dim / 2;
+    }
   } else {
     PADDLE_ENFORCE_EQ(rotary_embs.dims()[4], head_dim / 2);
   }
@@ -1196,23 +1376,45 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
       {kv_token_num, kv_num_heads, head_dim}, qkv.dtype(), qkv.place());
 
   if (use_neox_rotary_style) {
-    gqa_rotary_qk_split_variable_qwen3(qkv_out.data<data_t>(),
-                                       q.data<data_t>(),
-                                       k.data<data_t>(),
-                                       v.data<data_t>(),
-                                       qkv.data<data_t>(),
-                                       rotary_embs.data<float>(),
-                                       batch_id_per_token.data<int>(),
-                                       seq_lens_encoder.data<int>(),
-                                       seq_lens_decoder.data<int>(),
-                                       cu_seqlens_q.data<int>(),
-                                       cu_seqlens_k.data<int>(),
-                                       token_num,
-                                       num_heads,
-                                       kv_num_heads,
-                                       max_seq_len,
-                                       head_dim,
-                                       stream);
+    if (rotary_dim == head_dim) {
+      gqa_rotary_qk_split_variable_qwen3(qkv_out.data<data_t>(),
+                                         q.data<data_t>(),
+                                         k.data<data_t>(),
+                                         v.data<data_t>(),
+                                         qkv.data<data_t>(),
+                                         rotary_embs.data<float>(),
+                                         batch_id_per_token.data<int>(),
+                                         seq_lens_encoder.data<int>(),
+                                         seq_lens_decoder.data<int>(),
+                                         cu_seqlens_q.data<int>(),
+                                         cu_seqlens_k.data<int>(),
+                                         token_num,
+                                         num_heads,
+                                         kv_num_heads,
+                                         max_seq_len,
+                                         head_dim,
+                                         stream);
+    } else {
+      gqa_neox_partial_rotary_qk_split_variable(
+          qkv_out.data<data_t>(),
+          q.data<data_t>(),
+          k.data<data_t>(),
+          v.data<data_t>(),
+          qkv.data<data_t>(),
+          rotary_embs.data<float>(),
+          batch_id_per_token.data<int>(),
+          seq_lens_encoder.data<int>(),
+          seq_lens_decoder.data<int>(),
+          cu_seqlens_q.data<int>(),
+          cu_seqlens_k.data<int>(),
+          token_num,
+          num_heads,
+          kv_num_heads,
+          max_seq_len,
+          head_dim,
+          rotary_dim,
+          stream);
+    }
   } else {
     gqa_rotary_qk_split_variable(
         qkv_out.data<data_t>(),
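
For reviewers, the math in the new kernel is standard NEOX-style RoPE restricted to the first `rotary_dim` elements of each q/k head; elements at index `rotary_dim` and above are copied through unchanged, and `v` heads are never rotated. A minimal host-side sketch of the per-head update, assuming per-position cos/sin tables of length `rotary_dim / 2`; the helper name and the stand-in angles are mine, not from the patch:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// CPU reference for the rotation the kernel applies to one q/k head
// (hypothetical helper, for checking only): NEOX pairs element i with
// element i + rotary_dim/2 and rotates them by the position's angle.
void neox_partial_rope_ref(float *head, int head_dim, int rotary_dim,
                           const float *cos_tab, const float *sin_tab) {
  const int half = rotary_dim / 2;
  for (int i = 0; i < half; ++i) {
    const float x_left = head[i];
    const float x_right = head[i + half];
    const float c = cos_tab[i];
    const float s = sin_tab[i];
    head[i] = x_left * c - x_right * s;         // h_bias <  half_rotary_dim path
    head[i + half] = x_right * c + x_left * s;  // h_bias >= half_rotary_dim path
  }
  // head[rotary_dim .. head_dim) is left untouched, matching the kernel's
  // pass-through store for h_bias >= rotary_dim.
  (void)head_dim;
}

int main() {
  const int head_dim = 128, rotary_dim = 64;  // GLM case: rotary_dim = head_dim / 2
  std::vector<float> head(head_dim, 1.0f), cos_tab(rotary_dim / 2), sin_tab(rotary_dim / 2);
  for (int i = 0; i < rotary_dim / 2; ++i) {
    const float theta = 0.01f * i;  // stand-in angles; real ones come from rotary_embs
    cos_tab[i] = std::cos(theta);
    sin_tab[i] = std::sin(theta);
  }
  neox_partial_rope_ref(head.data(), head_dim, rotary_dim, cos_tab.data(), sin_tab.data());
  std::printf("head[0]=%.4f head[32]=%.4f head[64]=%.4f\n", head[0], head[32], head[64]);
  return 0;
}
```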
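The launch shape ties one warp to one 128-wide head row: `block_size` is `dim3(kWarpSize, blocksize / kWarpSize)`, i.e. 32 x 4 threads, and `PackSize = HEAD_DIM / kWarpSize = 4`, so lane `x` of a warp covers elements `[4x, 4x + 4)` of its row. The `assert(rotary_dim / 2 % PackSize == 0)` keeps every 4-wide vector entirely on one side of the `half_rotary_dim` boundary, so the `h_bias < half_rotary_dim` branch is uniform within a vector. A compile-time restatement of that arithmetic; the `static_assert`s are mine:

```cpp
#include <cuda_runtime.h>

constexpr int kWarpSize = 32;                   // CUDA warp size
constexpr int HEAD_DIM = 128;                   // wrapper asserts head_dim == 128
constexpr int PackSize = HEAD_DIM / kWarpSize;  // 4 elements per lane
constexpr int blocksize = 128;                  // threads per block

static_assert(kWarpSize * PackSize == HEAD_DIM,
              "one warp covers exactly one head row per loop iteration");
static_assert(blocksize % kWarpSize == 0,
              "the block is a whole number of warps: dim3(32, 4)");

// GLM partial-rotary case: rotary_dim = head_dim / 2 = 64.
constexpr int rotary_dim = HEAD_DIM / 2;
static_assert((rotary_dim / 2) % PackSize == 0,
              "a vector load never straddles the half_rotary_dim boundary");

int main() {
  dim3 block_size(kWarpSize, blocksize / kWarpSize);  // 32 x 4, as in the patch
  (void)block_size;
  return 0;
}
```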
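On the host side, the `rotary_embs.dims()[4] == head_dim / 4` check follows from the table layout: the buffer stores `rotary_dim / 2` angles per position, and for GLM `rotary_dim = head_dim / 2`, so the last dim is `head_dim / 4` (with `head_dim = 128`: 32 angles, `rotary_dim = 64`). The wrapper then splits the packed `[2, 1, max_model_len, 1, rotary_dim / 2]` buffer into a cos plane and a sin plane. A sketch of that slicing, with struct and helper names of my own:

```cpp
#include <cstdint>

// How the wrapper derives the two table pointers from the packed buffer;
// the layout follows the "[2, 1, seq_len, 1, head_dim / 4]" comment in the
// patch, and the names here are illustrative only.
struct RotarySlices {
  const float *cos_emb;  // plane 0: [max_model_len, rotary_dim / 2]
  const float *sin_emb;  // plane 1: [max_model_len, rotary_dim / 2]
};

RotarySlices slice_rotary_table(const float *rotary_emb,
                                int max_model_len,
                                int rotary_dim) {
  // Each plane holds max_model_len * rotary_dim / 2 floats, matching
  // "sin_emb = rotary_emb + max_model_len * rotary_dim / 2" in the wrapper.
  const int64_t plane =
      static_cast<int64_t>(max_model_len) * (rotary_dim / 2);
  return {rotary_emb, rotary_emb + plane};
}
```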
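The `cudaGridDependencySynchronize()` / `cudaTriggerProgrammaticLaunchCompletion()` pair, guarded to `__CUDA_ARCH__ >= 900`, is the device side of programmatic dependent launch: the grid may start before its predecessor retires, waits for the predecessor's writes only at the sync point, and signals its own completion early. The host side goes through `launchWithPdlWhenEnabled`, whose internals are not shown in this patch; a sketch of what such a wrapper typically does with the stock CUDA 11.8+ launch API:

```cpp
#include <cuda_runtime.h>

// Hypothetical stand-in for launchWithPdlWhenEnabled: launch a kernel with
// programmatic stream serialization allowed, so the next grid may overlap
// this one's tail. The real wrapper presumably also falls back to a plain
// <<<...>>> launch when PDL is disabled or unsupported.
template <typename... Args>
cudaError_t launch_with_pdl_sketch(void (*kernel)(Args...),
                                   dim3 grid,
                                   dim3 block,
                                   size_t dyn_smem,
                                   cudaStream_t stream,
                                   Args... args) {
  cudaLaunchConfig_t config = {};
  config.gridDim = grid;
  config.blockDim = block;
  config.dynamicSmemBytes = dyn_smem;
  config.stream = stream;

  cudaLaunchAttribute attr = {};
  attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attr.val.programmaticStreamSerializationAllowed = 1;
  config.attrs = &attr;
  config.numAttrs = 1;

  return cudaLaunchKernelEx(&config, kernel, args...);
}
```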