mirror of https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Graph Optimization] Support CUDAGraph Padding + MTP (#4545)
* Support CUDAGraph Padding + MTP
* support other write cache kernels
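The recurring change in the diff below is a guard at the top of each write-cache kernel. When a decode step replays a captured CUDA graph, the launch shape is fixed for the padded token count, and the tail slots beyond the real tokens carry -1 in batch_id_per_token; each kernel must bail out on such slots before it uses that id to index per-batch arrays. A minimal sketch of the pattern, with a hypothetical kernel name and payload (the real kernels below do the RoPE rotation and KV-cache write instead):

// Sketch only: assumes padded slots are marked batch_id_per_token[t] == -1
// and that padding occupies the tail of the flattened token range.
__global__ void write_cache_guarded(const int* batch_id_per_token,
                                    const int* seq_lens_decoder,
                                    float* cache,
                                    int padded_num_tokens) {
  const int token_id = blockIdx.x * blockDim.x + threadIdx.x;
  if (token_id >= padded_num_tokens) return;

  const int bid = batch_id_per_token[token_id];
  if (bid == -1) return;  // CUDAGraph padding slot: no real token here

  if (seq_lens_decoder[bid] == 0) return;  // request writes nothing this step

  cache[token_id] = static_cast<float>(bid);  // stand-in for the real KV write
}

Note the asymmetry in the grid-stride kernels below: the -1 check uses return while the seq_lens_decoder == 0 check uses continue, presumably because padding sits at the tail of the token range, so once a thread hits a padded token every later index it would visit is also padding.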
@@ -27,7 +27,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
     T* __restrict__ value_cache,  // [num_blocks, gqa_group_size, block_size,
                                   // head_size // 2]
     T* __restrict__ q_out,
-    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
+    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
     const int* __restrict__ batch_id_per_token,  // [num_tokens]
     const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens_decoder,  // [bsz]
@@ -68,10 +68,13 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
 
   const int64_t hidden_size = (num_heads + 2 * gqa_group_size) * head_size;
   const int half_head_size = head_size / 2;
-  for (int global_hi = global_warp_idx; global_hi < all_head_dim; global_hi += all_warp_num) {
+  for (int global_hi = global_warp_idx; global_hi < all_head_dim;
+       global_hi += all_warp_num) {
     int64_t linear_index = global_hi * head_size + threadIdx.x * VecSize;
     const int token_id = linear_index / hidden_size;
+
     const int ori_bi = batch_id_per_token[token_id];
+    if (ori_bi == -1) return;  // NOTE(gongshaotian): For CUDAGraph padding
     if (seq_lens_decoder[ori_bi] == 0) continue;
     const int bias = linear_index % hidden_size;
     const int hi = bias / head_size;  // q + k + v
@@ -84,7 +87,7 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
     const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
     const int block_idx = block_table_now[write_seq_id / block_size];
     if (block_idx < 0) {
-      return ;  // NOTE(gongshaotian): For CUDAGraph padding
+      return;  // NOTE(gongshaotian): For CUDAGraph padding
     }
     const int block_offset = write_seq_id % block_size;
 
@@ -102,7 +105,8 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
     if (hi < num_heads + gqa_group_size) {
       // q k rope
       const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
-      uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
+      uint32_t new_emb_idx =
+          rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
       Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
       Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
     }
@@ -136,20 +140,23 @@ __global__ void append_speculate_cache_T_rope_qk_norm_kernel(
     }
     if (hi < (num_heads + gqa_group_size)) {
       WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
-      float row_variance =
-          max(warp_m2 / head_size, 0.0f);
+      float row_variance = max(warp_m2 / head_size, 0.0f);
       float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
       if (hi < num_heads) {
-        Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
+        Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize],
+                             &q_norm_vec);
 #pragma unroll
         for (int i = 0; i < VecSize; i++) {
-          bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
+          bias_vec[i] =
+              static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
         }
       } else {
-        Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
+        Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize],
+                             &k_norm_vec);
 #pragma unroll
         for (int i = 0; i < VecSize; i++) {
-          bias_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
+          bias_vec[i] =
+              static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
         }
       }
     }
@@ -179,7 +186,7 @@ __global__ void append_clear_cache_int8_block(
     uint8_t* __restrict__ value_cache,  // [num_blocks, gqa_group_size,
                                         // block_size, head_size // 2]
     const int* __restrict__ seq_lens,
-    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
+    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
     const int* __restrict__ batch_id_per_token,  // [num_tokens]
     const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens_encoder,  // [bsz]
@@ -197,6 +204,7 @@ __global__ void append_clear_cache_int8_block(
   const int token_id = blockIdx.x;
 
   const int bid = batch_id_per_token[token_id];
+  if (bid == -1) return;  // NOTE(gongshaotian): For CUDAGraph padding
 
   const int start_token_idx = cu_seqlens_q[bid];
   const int head_idx = blockIdx.y * NUM_WARPS + wid;
@@ -245,7 +253,6 @@ __global__ void append_clear_cache_int8_block(
   }
 }
 
-
 template <int VecSize = 4, int HeadDim = 128>
 __global__ void append_clear_cache_int4_block(
     uint8_t* __restrict__ key_cache,  // [num_blocks, gqa_group_size,
@@ -253,7 +260,7 @@ __global__ void append_clear_cache_int4_block(
     uint8_t* __restrict__ value_cache,  // [num_blocks, gqa_group_size,
                                         // block_size, head_size // 2]
     const int* __restrict__ seq_lens,
-    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
+    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
     const int* __restrict__ batch_id_per_token,  // [num_tokens]
     const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens_encoder,  // [bsz]
@@ -271,6 +278,7 @@ __global__ void append_clear_cache_int4_block(
   const int token_id = blockIdx.x;
 
   const int bid = batch_id_per_token[token_id];
+  if (bid == -1) return;  // NOTE(gongshaotian): For CUDAGraph padding
 
   const int start_token_idx = cu_seqlens_q[bid];
   const int head_idx = blockIdx.y * NUM_WARPS + wid;
@@ -331,7 +339,7 @@ __global__ void append_speculate_cache_rope_kernel(
     T* __restrict__ value_cache,  // [num_blocks, gqa_group_size, block_size,
                                   // head_size // 2]
     T* __restrict__ q_out,
-    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
+    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
     const int* __restrict__ batch_id_per_token,  // [num_tokens]
     const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens_decoder,  // [bsz]
@@ -370,6 +378,8 @@ __global__ void append_speculate_cache_rope_kernel(
        linear_index += step) {
     const int token_id = linear_index / hidden_size;
     const int ori_bi = batch_id_per_token[token_id];
+    if (ori_bi == -1) return;  // NOTE(gongshaotian): For CUDAGraph padding
+
     if (seq_lens_decoder[ori_bi] == 0) continue;
     const int bias = linear_index % hidden_size;
     const int hi = bias / head_size;  // q + k + v
@@ -382,7 +392,7 @@ __global__ void append_speculate_cache_rope_kernel(
     const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
     const int block_idx = block_table_now[write_seq_id / block_size];
     if (block_idx < 0) {
-      return ;  // NOTE(gongshaotian): For CUDAGraph padding
+      return;  // NOTE(gongshaotian): For CUDAGraph padding
     }
     const int block_offset = write_seq_id % block_size;
 
@@ -400,7 +410,8 @@ __global__ void append_speculate_cache_rope_kernel(
     if (hi < num_heads + gqa_group_size) {
       // q k rope
       const int64_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
-      int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
+      int64_t new_emb_idx =
+          rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
       Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
       Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
     }
@@ -458,7 +469,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
     T* __restrict__ value_cache,  // [num_blocks, gqa_group_size, block_size,
                                   // head_size // 2]
     T* __restrict__ qkv_out,
-    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
+    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
     const int* __restrict__ batch_id_per_token,  // [num_tokens]
     const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens_decoder,  // [bsz]
@@ -497,6 +508,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
        linear_index += step) {
     const int token_id = linear_index / half_hidden_size;
     const int ori_bi = batch_id_per_token[token_id];
+    if (ori_bi == -1) return;  // NOTE(gongshaotian): For CUDAGraph padding
     if (seq_lens_decoder[ori_bi] == 0) continue;
     const int bias = linear_index % half_hidden_size;
     const int hi = bias / half_head_size;  // q + k + v
@@ -509,7 +521,7 @@ __global__ void append_speculate_cache_neox_rope_kernel(
     const int* block_table_now = block_tables + ori_bi * max_blocks_per_seq;
     const int block_idx = block_table_now[write_seq_id / block_size];
     if (block_idx < 0) {
-      return ;  // NOTE(gongshaotian): For CUDAGraph padding
+      return;  // NOTE(gongshaotian): For CUDAGraph padding
     }
     const int block_offset = write_seq_id % block_size;
 
@@ -531,7 +543,8 @@ __global__ void append_speculate_cache_neox_rope_kernel(
     if (hi < num_heads + gqa_group_size) {
       // q k rope
       const int64_t emb_idx = write_seq_id * head_size + h_bias;
-      int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2: emb_idx;
+      int64_t new_emb_idx =
+          rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
       Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
       Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
     }
@@ -591,14 +604,14 @@ template <typename T,
           int HeadDim = 128,
           bool IsFP8 = false>
 __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
-    const T* __restrict__ quant_qkv,  // [num_head, num_heads + 2 *
+    const T* __restrict__ quant_qkv,  // [num_head, num_heads + 2 *
                                       // gqa_group_size, head_size]
     uint8_t* __restrict__ key_cache,  // [num_blocks, gqa_group_size,
                                       // block_size, head_size // 2]
     uint8_t* __restrict__ value_cache,  // [num_blocks, gqa_group_size,
                                         // block_size, head_size // 2]
     T* __restrict__ qkv_out,
-    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
+    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
     const int* __restrict__ batch_id_per_token,  // [num_tokens]
     const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens,  // [bsz]
@@ -627,6 +640,7 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
   const int token_id = blockIdx.x;
 
   const int bid = batch_id_per_token[token_id];
+  if (bid == -1) return;  // NOTE(gongshaotian): For CUDAGraph padding
 
   const int start_token_idx = cu_seqlens_q[bid];
   const int head_idx = blockIdx.y * NUM_WARPS + wid;
@@ -644,10 +658,12 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
   if (head_idx < num_heads) {
     cache_offset = 0;
   } else if (head_idx < num_heads + 2 * gqa_group_size) {
-    cache_offset = block_idx * gqa_group_size * block_size + (head_idx - num_heads) % gqa_group_size * block_size + block_offset;
+    cache_offset = block_idx * gqa_group_size * block_size +
+                   (head_idx - num_heads) % gqa_group_size * block_size +
+                   block_offset;
   }
-  T *cache_k_scale_now = cache_k_scale + cache_offset;
-  T *cache_v_scale_now = cache_v_scale + cache_offset;
+  T* cache_k_scale_now = cache_k_scale + cache_offset;
+  T* cache_v_scale_now = cache_v_scale + cache_offset;
 
   float thread_m2 = 0.0f;
   float warp_m2 = 0.0f;
@@ -675,7 +691,8 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
 
     // q rope
     const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
-    uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+    uint32_t new_emb_idx =
+        rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
     Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
     Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
 #pragma unroll
@@ -688,22 +705,20 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
       float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
       float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
       thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
-      bias_vec[2 * i] =
-          static_cast<T>(tmp1);
-      bias_vec[2 * i + 1] =
-          static_cast<T>(tmp2);
+      bias_vec[2 * i] = static_cast<T>(tmp1);
+      bias_vec[2 * i + 1] = static_cast<T>(tmp2);
     }
     // qk norm
     if (q_norm_weight) {
       WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
-      float row_variance =
-          max(warp_m2 / HeadDim, 0.0f);
+      float row_variance = max(warp_m2 / HeadDim, 0.0f);
       float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
       LoadOutScaleT q_norm_vec;
       Load<float, VecSize>(&q_norm_weight[lane_id * VecSize], &q_norm_vec);
 #pragma unroll
       for (int i = 0; i < VecSize; i++) {
-        bias_vec[i] = static_cast<T>(static_cast<float>(bias_vec[i]) * row_inv_var * q_norm_vec[i]);
+        bias_vec[i] = static_cast<T>(static_cast<float>(bias_vec[i]) *
+                                     row_inv_var * q_norm_vec[i]);
       }
     }
     Store<T, VecSize>(bias_vec, &qkv_out_now[bias_idx]);
@@ -739,7 +754,8 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
     const int v_head_idx = head_idx - num_heads - gqa_group_size;
     if (head_idx < num_heads + gqa_group_size) {
       const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
-      uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+      uint32_t new_emb_idx =
+          rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
       Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
       Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
       Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
@@ -754,10 +770,8 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
         float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
         float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
         thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
-        bias_vec1[0] =
-            static_cast<T>(tmp1);
-        bias_vec1[1] =
-            static_cast<T>(tmp2);
+        bias_vec1[0] = static_cast<T>(tmp1);
+        bias_vec1[1] = static_cast<T>(tmp2);
       } else {
         bias_vec1[0] = static_cast<T>(input_left);
         bias_vec1[1] = static_cast<T>(input_right);
@@ -771,10 +785,8 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
         float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
         float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
         thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
-        bias_vec2[0] =
-            static_cast<T>(tmp1);
-        bias_vec2[1] =
-            static_cast<T>(tmp2);
+        bias_vec2[0] = static_cast<T>(tmp1);
+        bias_vec2[1] = static_cast<T>(tmp2);
       } else {
        bias_vec2[0] = static_cast<T>(input_left);
         bias_vec2[1] = static_cast<T>(input_right);
@@ -783,16 +795,18 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
       if (head_idx < num_heads + gqa_group_size) {
         LoadOutScaleT k_norm_vec1, k_norm_vec2;
         Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias], &k_norm_vec1);
-        Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias + 8], &k_norm_vec2);
+        Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias + 8],
+                                     &k_norm_vec2);
         // qk norm
         WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
-        float row_variance =
-            max(warp_m2 / HeadDim, 0.0f);
+        float row_variance = max(warp_m2 / HeadDim, 0.0f);
         float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
 
         for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
-          bias_vec1[i] = static_cast<T>(static_cast<float>(bias_vec1[i]) * row_inv_var * k_norm_vec1[i]);
-          bias_vec2[i] = static_cast<T>(static_cast<float>(bias_vec2[i]) * row_inv_var * k_norm_vec2[i]);
+          bias_vec1[i] = static_cast<T>(static_cast<float>(bias_vec1[i]) *
+                                        row_inv_var * k_norm_vec1[i]);
+          bias_vec2[i] = static_cast<T>(static_cast<float>(bias_vec2[i]) *
+                                        row_inv_var * k_norm_vec2[i]);
         }
       }
     }
@@ -805,7 +819,8 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
     }
 #pragma unroll
     for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
-      local_max = __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
+      local_max =
+          __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
     }
 
     scale = __hdiv(448, local_max);
@@ -820,8 +835,10 @@ __global__ void append_speculate_cache_fp8_rope_qk_norm_dynamic_kernel(
 
 #pragma unroll
     for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) {
-      cache_vec[i] = QuantToC8<T,true, IsFP8, RoundType>(scale, bias_vec1[i], max_bound, min_bound);
-      cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T,true, IsFP8, RoundType>(scale, bias_vec2[i], max_bound, min_bound);
+      cache_vec[i] = QuantToC8<T, true, IsFP8, RoundType>(
+          scale, bias_vec1[i], max_bound, min_bound);
+      cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T, true, IsFP8, RoundType>(
+          scale, bias_vec2[i], max_bound, min_bound);
     }
     if (head_idx < num_heads + gqa_group_size) {
       const int start_block_16 =
@@ -866,7 +883,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
     uint8_t* __restrict__ value_cache,  // [num_blocks, gqa_group_size,
                                         // block_size, head_size // 2]
     T* __restrict__ qkv_out,
-    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
+    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
     const int* __restrict__ batch_id_per_token,  // [num_tokens]
     const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens,  // [bsz]
@@ -896,6 +913,7 @@ __global__ void append_speculate_cache_int8_rope_kernel(
   const int token_id = blockIdx.x;
 
   const int bid = batch_id_per_token[token_id];
+  if (bid == -1) return;  // NOTE(gongshaotian): For CUDAGraph padding
 
   const int start_token_idx = cu_seqlens_q[bid];
   const int head_idx = blockIdx.y * NUM_WARPS + wid;
@@ -932,7 +950,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
 
     // q rope
     const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
-    uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+    uint32_t new_emb_idx =
+        rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
     Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
     Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
     if (qkv_out_scales) {
@@ -994,7 +1013,8 @@ __global__ void append_speculate_cache_int8_rope_kernel(
     T scale;
     if (head_idx < num_heads + gqa_group_size) {
       const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
-      uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+      uint32_t new_emb_idx =
+          rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
       Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
       Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
       Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
@@ -1056,8 +1076,10 @@ __global__ void append_speculate_cache_int8_rope_kernel(
     }
 #pragma unroll
     for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) {
-      cache_vec[i] = QuantToC8<T,true, IsFP8, RoundType>(scale, bias_vec1[i], max_bound, min_bound);
-      cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T,true, IsFP8, RoundType>(scale, bias_vec2[i], max_bound, min_bound);
+      cache_vec[i] = QuantToC8<T, true, IsFP8, RoundType>(
+          scale, bias_vec1[i], max_bound, min_bound);
+      cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T, true, IsFP8, RoundType>(
+          scale, bias_vec2[i], max_bound, min_bound);
     }
     if (head_idx < num_heads + gqa_group_size) {
       const int start_block_16 =
@@ -1101,7 +1123,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
     uint8_t* __restrict__ value_cache,  // [num_blocks, gqa_group_size,
                                         // block_size, head_size // 2]
     T* __restrict__ qkv_out,
-    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
+    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
     const int* __restrict__ batch_id_per_token,  // [num_tokens]
     const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens,  // [bsz]
@@ -1131,6 +1153,7 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
   const int token_id = blockIdx.x;
 
   const int bid = batch_id_per_token[token_id];
+  if (bid == -1) return;  // NOTE(gongshaotian): For CUDAGraph padding
 
   const int start_token_idx = cu_seqlens_q[bid];
   const int head_idx = blockIdx.y * NUM_WARPS + wid;
@@ -1170,7 +1193,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
 
     // q rope
     const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
-    uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
+    uint32_t new_emb_idx =
+        rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
     Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
     Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
     if (qkv_out_scales) {
@@ -1267,7 +1291,8 @@ __global__ void append_speculate_cache_int8_neox_rope_kernel(
 
     T scale;
     const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
-    uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
+    uint32_t new_emb_idx =
+        rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
     Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
     Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
     Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
@@ -1482,7 +1507,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
     uint8_t* __restrict__ value_cache,  // [num_blocks, gqa_group_size,
                                         // block_size, head_size // 2]
     T* __restrict__ qkv_out,
-    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
+    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
     const int* __restrict__ batch_id_per_token,  // [num_tokens]
     const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens,  // [bsz]
@@ -1515,6 +1540,7 @@ __global__ void append_speculate_cache_int4_rope_kernel(
   const int token_id = blockIdx.x;
 
   const int bid = batch_id_per_token[token_id];
+  if (bid == -1) return;  // NOTE(gongshaotian): For CUDAGraph padding
 
   const int start_token_idx = cu_seqlens_q[bid];
   const int head_idx = blockIdx.y * NUM_WARPS + wid;
@@ -1561,7 +1587,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
     // Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
     // q rope
     const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
-    uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+    uint32_t new_emb_idx =
+        rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
     Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
     Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
 #pragma unroll
@@ -1652,7 +1679,8 @@ __global__ void append_speculate_cache_int4_rope_kernel(
     // &out_scale_vec2);
     if (head_idx < num_heads + gqa_group_size) {
       const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
-      uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+      uint32_t new_emb_idx =
+          rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
       Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
       Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
       Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
@@ -1759,7 +1787,6 @@ __global__ void append_speculate_cache_int4_rope_kernel(
       }
       Store<uint8_t, K_VEC_SIZE>(cache_vec, &key_cache[tgt_cache_idx]);
     } else {
-
       const uint32_t base_tgt_cache_idx =
          block_idx * gqa_group_size * HeadDim * half_block_size +
          kv_head_idx * HeadDim * half_block_size +
@@ -1828,7 +1855,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
     uint8_t* __restrict__ value_cache,  // [num_blocks, gqa_group_size,
                                         // block_size, head_size // 2]
     T* __restrict__ qkv_out,
-    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
+    const int* __restrict__ block_tables,  // [bsz, max_blocks_per_seq]
     const int* __restrict__ batch_id_per_token,  // [num_tokens]
     const int* __restrict__ cu_seqlens_q,
     const int* __restrict__ seq_lens,  // [bsz]
@@ -1861,6 +1888,7 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
   const int token_id = blockIdx.x;
 
   const int bid = batch_id_per_token[token_id];
+  if (bid == -1) return;  // NOTE(gongshaotian): For CUDAGraph padding
 
   const int start_token_idx = cu_seqlens_q[bid];
   const int head_idx = blockIdx.y * NUM_WARPS + wid;
@@ -2000,7 +2028,8 @@ __global__ void append_speculate_cache_int4_neox_rope_kernel(
                                  &right_out_scale_vec2);
 
     const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
-    uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
+    uint32_t new_emb_idx =
+        rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
     Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
     Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
     Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
@@ -2038,7 +2067,6 @@
           right_bias_vec1[i] =
              static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
 
-
           input_left = static_cast<float>(left_src_vec2[i]);
           input_right = static_cast<float>(right_src_vec2[i]);
           cos_tmp = cos_emb_vec2[i];
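The guards above only work if the host side marks the padded slots before replay. For context, a hedged sketch of that bookkeeping (hypothetical helper, not the FastDeploy API): the runner pads the flattened token count up to the size the graph was captured with, and under MTP-style speculative decoding each request contributes several draft tokens per step, which is why these speculate write-cache kernels needed the same guard as the plain decode path.

// Hypothetical host-side setup for CUDAGraph padding (illustrative only).
#include <vector>

void fill_batch_ids_for_replay(std::vector<int>& batch_id_per_token,
                               const std::vector<int>& tokens_per_request,
                               int padded_num_tokens) {
  // Default every slot to -1 so the kernels treat the tail as padding.
  batch_id_per_token.assign(padded_num_tokens, -1);
  int t = 0;
  for (int bid = 0; bid < static_cast<int>(tokens_per_request.size()); ++bid) {
    // With MTP each request owns several draft tokens in one step.
    for (int k = 0; k < tokens_per_request[bid]; ++k) {
      batch_id_per_token[t++] = bid;
    }
  }
  // The array is then copied to device memory before the graph is replayed.
}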