[Others] Remove useless code (#5404)

This commit is contained in:
周周周
2025-12-08 13:59:46 +08:00
committed by GitHub
parent 3066a0c34b
commit 2aea8a3a60
8 changed files with 139 additions and 166 deletions

View File

@@ -28,7 +28,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
const float *k_norm_weight,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_encoder,
const int *seq_lens_decoder,
const int *cu_seqlens_k,
T *qkv_out,
@@ -38,8 +38,8 @@ __global__ void GQAVariableLengthRotarySplitKernel(
const int64_t elem_cnt,
const int q_num_head,
const int kv_num_head,
const int seq_len,
const int last_dim,
const int max_model_len,
const int head_dim,
const bool rope_3d,
const float rms_norm_eps) {
using LoadT = AlignedVector<T, VecSize>;
@@ -53,30 +53,33 @@ __global__ void GQAVariableLengthRotarySplitKernel(
LoadFloat q_norm_vec, k_norm_vec;
int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
int64_t all_warp_num = gridDim.x * blockDim.y;
const int half_lastdim = last_dim / 2;
const int half_headdim = head_dim / 2;
const int offset =
(q_num_head + kv_num_head * 2) * last_dim; // for all q,k,v
const int all_head_num = elem_cnt / last_dim;
(q_num_head + kv_num_head * 2) * head_dim; // for all q,k,v
const int all_head_num = elem_cnt / head_dim;
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num;
gloabl_hi += all_warp_num) {
int64_t linear_index =
gloabl_hi * last_dim + threadIdx.x * VecSize; // 全局index
gloabl_hi * head_dim + threadIdx.x * VecSize; // 全局index
const int token_idx =
linear_index / offset; // token id(第几个token,不分qkv)
const int ori_bi = batch_id_per_token[token_idx]; // 第几个batch
if (seq_lens[ori_bi] == 0) continue;
int cache_kv_len = seq_lens_decoder[ori_bi];
// 这里其实是不需要处理的但是由于FA3的bug所以必须
if (seq_lens_encoder[ori_bi] == 0) cache_kv_len = 0;
const int bias = linear_index % offset;
const int hi = bias / last_dim;
const int h_bias = bias % last_dim;
const int hi = bias / head_dim;
const int h_bias = bias % head_dim;
const int ori_seq_id =
(token_idx - cu_seqlens_q[ori_bi]) +
seq_lens_decoder
[ori_bi]; // 在当前seq中的id(拼接了seq到一个batch的情况下有效)
cache_kv_len; // 在当前seq中的id(拼接了seq到一个batch的情况下有效)
const int64_t emb_idx =
ori_seq_id * half_lastdim + h_bias / 2; // embedding的id
ori_seq_id * half_headdim + h_bias / 2; // embedding的id
const int64_t base_idx =
token_idx * (q_num_head + 2 * kv_num_head) * last_dim + hi * last_dim +
token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim +
h_bias;
Load<T, VecSize>(&qkv[base_idx], &src_vec);
const int kv_write_idx = cu_seqlens_k[ori_bi] + ori_seq_id;
@@ -84,21 +87,21 @@ __global__ void GQAVariableLengthRotarySplitKernel(
T *out_p = nullptr;
if (hi < q_num_head) {
base_split_idx =
token_idx * q_num_head * last_dim + hi * last_dim + h_bias;
token_idx * q_num_head * head_dim + hi * head_dim + h_bias;
out_p = q;
} else if (hi < q_num_head + kv_num_head) {
base_split_idx = kv_write_idx * kv_num_head * last_dim +
(hi - q_num_head) * last_dim + h_bias;
base_split_idx = kv_write_idx * kv_num_head * head_dim +
(hi - q_num_head) * head_dim + h_bias;
out_p = k;
} else {
out_p = v;
base_split_idx = kv_write_idx * kv_num_head * last_dim +
(hi - q_num_head - kv_num_head) * last_dim + h_bias;
base_split_idx = kv_write_idx * kv_num_head * head_dim +
(hi - q_num_head - kv_num_head) * head_dim + h_bias;
}
// TODO check this correct or not
int64_t new_emb_idx =
rope_3d ? emb_idx + ori_bi * last_dim * seq_len : emb_idx;
rope_3d ? emb_idx + ori_bi * head_dim * max_model_len : emb_idx;
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
@@ -122,7 +125,7 @@ __global__ void GQAVariableLengthRotarySplitKernel(
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2); // 单个head的标准差
if (hi < q_num_head + kv_num_head) { // only q and k need norm
float row_variance = max(warp_m2 / last_dim, 0.0f);
float row_variance = max(warp_m2 / head_dim, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
if (hi < q_num_head) {
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize],
@@ -165,12 +168,12 @@ __global__ void GQAVariableLengthRotarySplitKernel(
template <typename T>
void gqa_rotary_qk_split_variable(
T *qkv_out, // [token_num, 3, num_head, dim_head]
T *qkv_out, // [token_num, 3, num_head, head_dim]
T *q,
T *k,
T *v,
const T *qkv_input,
const float *rotary_emb, // [2, 1, 1, seq_len, dim_head / 2]
const float *rotary_emb, // [2, 1, 1, seq_len, head_dim / 2]
const float *q_norm_weight,
const float *k_norm_weight,
const int *batch_id_per_token,
@@ -181,14 +184,14 @@ void gqa_rotary_qk_split_variable(
const int token_num,
const int num_heads,
const int kv_num_heads,
const int seq_len,
const int max_model_len,
const int input_output_len,
const int dim_head,
const int head_dim,
const bool rope_3d,
const float rms_norm_eps,
const cudaStream_t &stream) {
assert(dim_head == 128 && "dim_head must be 128");
int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * dim_head;
assert(head_dim == 128 && "head_dim must be 128");
int64_t elem_nums = token_num * (num_heads + 2 * kv_num_heads) * head_dim;
constexpr int HEAD_DIM = 128;
constexpr int PackSize = HEAD_DIM / kWarpSize;
@@ -199,7 +202,7 @@ void gqa_rotary_qk_split_variable(
dim3 block_size(kWarpSize, blocksize / kWarpSize);
const float *cos_emb = rotary_emb;
const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;
const float *sin_emb = rotary_emb + input_output_len * head_dim / 2;
launchWithPdlWhenEnabled(GQAVariableLengthRotarySplitKernel<T, PackSize>,
grid_size,
block_size,
@@ -222,8 +225,8 @@ void gqa_rotary_qk_split_variable(
elem_nums,
num_heads,
kv_num_heads,
seq_len,
dim_head,
max_model_len,
head_dim,
rope_3d,
rms_norm_eps);
}
@@ -1163,9 +1166,6 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
meta_data.block_size = block_size;
meta_data.batch_size = seq_lens_this_time.dims()[0];
phi::GPUContext *dev_ctx = static_cast<phi::GPUContext *>(
phi::DeviceContextPool::Instance().Get(qkv.place()));
auto stream = qkv.stream();
paddle::Tensor qkv_out = GetEmptyTensor(qkv.dims(), qkv.dtype(), qkv.place());
paddle::Tensor q = GetEmptyTensor(

View File

@@ -16,25 +16,26 @@
#include "paddle/extension.h"
#include "paddle/phi/core/memory/memcpy.h"
__global__ void pre_cache_len_concat(const int* __restrict__ seq_lens_decoder,
const int* __restrict__ seq_lens_this_time,
int* __restrict__ cu_seqlens_k,
int* __restrict__ batch_ids,
int* __restrict__ tile_ids_per_batch,
int* __restrict__ num_blocks_x,
int* __restrict__ kv_token_num,
const int bsz,
const int num_row_per_block) {
__global__ void pre_cache_len_concat(const int* __restrict__ seq_lens_encoder,
const int* __restrict__ seq_lens_decoder,
const int* __restrict__ seq_lens_this_time,
int* __restrict__ cu_seqlens_k,
int* __restrict__ batch_ids,
int* __restrict__ tile_ids_per_batch,
int* __restrict__ num_blocks_x,
int* __restrict__ kv_token_num,
const int bsz,
const int num_row_per_block) {
if (threadIdx.x == 0) {
int gridx = 0;
int index = 0;
int total_tokens = 0;
cu_seqlens_k[0] = 0;
for (uint32_t bid = 0; bid < bsz; bid++) {
int cache_len = seq_lens_decoder[bid];
const int q_len = seq_lens_this_time[bid];
if (q_len <= 0) {
cache_len = 0;
int cache_len = 0;
if (seq_lens_encoder[bid] > 0) {
// only deal with chunked prefill case.
cache_len = seq_lens_decoder[bid];
}
const int loop_times = div_up(cache_len, num_row_per_block);
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
@@ -42,6 +43,7 @@ __global__ void pre_cache_len_concat(const int* __restrict__ seq_lens_decoder,
tile_ids_per_batch[index++] = tile_id;
}
gridx += loop_times;
const int q_len = seq_lens_this_time[bid];
total_tokens += (cache_len + q_len);
cu_seqlens_k[bid + 1] = total_tokens;
}
@@ -51,6 +53,7 @@ __global__ void pre_cache_len_concat(const int* __restrict__ seq_lens_decoder,
}
std::vector<paddle::Tensor> PreCacheLenConcat(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const int max_dec_len,
@@ -58,45 +61,43 @@ std::vector<paddle::Tensor> PreCacheLenConcat(
auto stream = seq_lens_decoder.stream();
auto place = seq_lens_decoder.place();
int bsz = seq_lens_this_time.shape()[0];
const uint32_t max_tile_size_per_bs_pre_cache = div_up(max_dec_len, block_size);
const uint32_t max_tile_size_per_bs_pre_cache =
div_up(max_dec_len, block_size);
paddle::Tensor cu_seqlens_k = GetEmptyTensor(
{bsz + 1},
paddle::DataType::INT32,
place);
paddle::Tensor cu_seqlens_k =
GetEmptyTensor({bsz + 1}, paddle::DataType::INT32, place);
paddle::Tensor pre_cache_batch_ids = GetEmptyTensor(
{bsz * max_tile_size_per_bs_pre_cache},
paddle::DataType::INT32,
place);
{bsz * max_tile_size_per_bs_pre_cache}, paddle::DataType::INT32, place);
paddle::Tensor pre_cache_tile_ids_per_batch = GetEmptyTensor(
{bsz * max_tile_size_per_bs_pre_cache},
paddle::DataType::INT32,
place);
{bsz * max_tile_size_per_bs_pre_cache}, paddle::DataType::INT32, place);
paddle::Tensor pre_cache_num_blocks =
GetEmptyTensor({1}, paddle::DataType::INT32, place);
GetEmptyTensor({1}, paddle::DataType::INT32, place);
paddle::Tensor kv_token_num =
GetEmptyTensor({1}, paddle::DataType::INT32, place);
GetEmptyTensor({1}, paddle::DataType::INT32, place);
pre_cache_len_concat<<<1, 32, 0, stream>>>(
seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
cu_seqlens_k.data<int>(),
pre_cache_batch_ids.data<int>(),
pre_cache_tile_ids_per_batch.data<int>(),
pre_cache_num_blocks.data<int>(),
kv_token_num.data<int>(),
bsz,
block_size
);
paddle::Tensor pre_cache_num_blocks_cpu = pre_cache_num_blocks.copy_to(paddle::CPUPlace(), false);
paddle::Tensor kv_token_num_cpu = kv_token_num.copy_to(paddle::CPUPlace(), false);
seq_lens_encoder.data<int>(),
seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
cu_seqlens_k.data<int>(),
pre_cache_batch_ids.data<int>(),
pre_cache_tile_ids_per_batch.data<int>(),
pre_cache_num_blocks.data<int>(),
kv_token_num.data<int>(),
bsz,
block_size);
paddle::Tensor pre_cache_num_blocks_cpu =
pre_cache_num_blocks.copy_to(paddle::CPUPlace(), false);
paddle::Tensor kv_token_num_cpu =
kv_token_num.copy_to(paddle::CPUPlace(), false);
return {cu_seqlens_k,
pre_cache_batch_ids,
pre_cache_tile_ids_per_batch,
pre_cache_num_blocks_cpu, /*cpu*/
kv_token_num_cpu /*cpu*/
};
return {
cu_seqlens_k,
pre_cache_batch_ids,
pre_cache_tile_ids_per_batch,
pre_cache_num_blocks_cpu, /*cpu*/
kv_token_num_cpu /*cpu*/
};
}
std::vector<paddle::DataType> PreCacheLenConcatInferDtype(
@@ -121,15 +122,13 @@ std::vector<std::vector<int64_t>> PreCacheLenConcatInferShape(
}
PD_BUILD_STATIC_OP(pre_cache_len_concat)
.Inputs({"seq_lens_decoder",
"seq_lens_this_time"})
.Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time"})
.Outputs({"cu_seqlens_k",
"pre_cache_batch_ids",
"pre_cache_tile_ids_per_batch",
"pre_cache_num_blocks_cpu", /*cpu*/
"kv_token_num_cpu"}) /*cpu*/
.Attrs({"max_dec_len: int",
"block_size: int"})
"kv_token_num_cpu"}) /*cpu*/
.Attrs({"max_dec_len: int", "block_size: int"})
.SetKernelFn(PD_KERNEL(PreCacheLenConcat))
.SetInferShapeFn(PD_INFER_SHAPE(PreCacheLenConcatInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(PreCacheLenConcatInferDtype));

View File

@@ -194,6 +194,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
const bool rope_3d);
std::vector<paddle::Tensor> PreCacheLenConcat(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const int max_dec_len,