mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
support static C8 (#4568)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled
This commit is contained in:
@@ -18,19 +18,16 @@
|
||||
#include "mma_tensor_op.cuh"
|
||||
#include "utils.cuh"
|
||||
|
||||
|
||||
// Note(ZKK)
|
||||
// This function is very easy!
|
||||
// just make HeadDim data to be new HeadDim data!
|
||||
|
||||
template <typename T, int VecSize=8, int HEAD_DIM=128, int NUM_THREADS=32>
|
||||
__device__ __forceinline__ void apply_rope(
|
||||
const T* input,
|
||||
const float* cos_emb,
|
||||
const float* sin_emb,
|
||||
T* output,
|
||||
const int thread_id) {
|
||||
|
||||
template <typename T, int VecSize = 8, int HEAD_DIM = 128, int NUM_THREADS = 32>
|
||||
__device__ __forceinline__ void apply_rope(const T* input,
|
||||
const float* cos_emb,
|
||||
const float* sin_emb,
|
||||
T* output,
|
||||
const int thread_id) {
|
||||
using LoadT = AlignedVector<T, VecSize>;
|
||||
using LoadBiasT = AlignedVector<T, VecSize>;
|
||||
using LoadOutScaleT = AlignedVector<float, VecSize>;
|
||||
@@ -43,38 +40,38 @@ __device__ __forceinline__ void apply_rope(
|
||||
LoadEmbT sin_emb_vec;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t head_bias = thread_id * VecSize; head_bias < HEAD_DIM; head_bias += NUM_THREADS * VecSize) {
|
||||
Load<T, VecSize>(&input[head_bias], &src_vec);
|
||||
const uint32_t emb_idx = head_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
for (uint32_t head_bias = thread_id * VecSize; head_bias < HEAD_DIM;
|
||||
head_bias += NUM_THREADS * VecSize) {
|
||||
Load<T, VecSize>(&input[head_bias], &src_vec);
|
||||
const uint32_t emb_idx = head_bias / 2;
|
||||
Load<float, HalfVecSize>(&cos_emb[emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
for (int i = 0; i < HalfVecSize; i++) {
|
||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
|
||||
float input_left = static_cast<float>(src_vec[2 * i]);
|
||||
float input_right = static_cast<float>(src_vec[2 * i + 1]);
|
||||
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
out_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
out_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
}
|
||||
Store<T, VecSize>(out_vec, &output[head_bias]);
|
||||
const float cos_tmp = cos_emb_vec[i];
|
||||
const float sin_tmp = sin_emb_vec[i];
|
||||
out_vec[2 * i] =
|
||||
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
|
||||
out_vec[2 * i + 1] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
}
|
||||
Store<T, VecSize>(out_vec, &output[head_bias]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 1>
|
||||
__global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
@@ -112,7 +109,8 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
|
||||
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
|
||||
const int half_head_size = head_size / 2;
|
||||
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_dim; gloabl_hi += all_warp_num) {
|
||||
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_dim;
|
||||
gloabl_hi += all_warp_num) {
|
||||
int64_t linear_index = gloabl_hi * head_size + threadIdx.x * VecSize;
|
||||
const int ori_bi = linear_index / hidden_size;
|
||||
const int bias = linear_index % hidden_size;
|
||||
@@ -136,7 +134,8 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
if (hi < num_heads + kv_num_heads) {
|
||||
// q k rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
|
||||
uint32_t new_emb_idx =
|
||||
rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
@@ -162,20 +161,21 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
out_vec[2 * i + 1] = src_vec[2 * i + 1];
|
||||
}
|
||||
}
|
||||
if (hi < (num_heads + kv_num_heads)) { // q k
|
||||
if (hi < (num_heads + kv_num_heads)) { // q k
|
||||
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
||||
float row_variance =
|
||||
max(warp_m2 / head_size, 0.0f);
|
||||
float row_variance = max(warp_m2 / head_size, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
if (hi < num_heads) { // q
|
||||
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
|
||||
#pragma unroll
|
||||
if (hi < num_heads) { // q
|
||||
Load<float, VecSize>(&q_norm_weight[threadIdx.x * VecSize],
|
||||
&q_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * q_norm_vec[i]);
|
||||
}
|
||||
} else { // k
|
||||
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
|
||||
#pragma unroll
|
||||
} else { // k
|
||||
Load<float, VecSize>(&k_norm_weight[threadIdx.x * VecSize],
|
||||
&k_norm_vec);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
out_vec[i] = static_cast<T>(tmp_vec[i] * row_inv_var * k_norm_vec[i]);
|
||||
}
|
||||
@@ -197,7 +197,6 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
|
||||
Store<T, VecSize>(out_vec, &value_cache[tgt_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -205,12 +204,12 @@ template <typename T, int VecSize = 1>
|
||||
__global__ void append_decode_cache_T_rope_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
@@ -266,7 +265,8 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
if (hi < num_heads + kv_num_heads) {
|
||||
// q k rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
|
||||
uint32_t new_emb_idx =
|
||||
rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
@@ -316,7 +316,7 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
@@ -325,8 +325,8 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
const float* __restrict__ sin_emb,
|
||||
const float* __restrict__ qkv_out_scales, // [num_head + 2 *
|
||||
// kv_num_heads, dim_head]
|
||||
const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads,
|
||||
// dim_head]
|
||||
const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads,
|
||||
// dim_head]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
@@ -382,7 +382,8 @@ __global__ void append_decode_cache_T_rope_kernel(
|
||||
if (hi < num_heads + kv_num_heads) {
|
||||
// q k rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + h_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
|
||||
uint32_t new_emb_idx =
|
||||
rope_3d ? emb_idx + ori_bi * max_seq_len * head_size : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
@@ -437,12 +438,14 @@ __global__ void append_decode_cache_T_neox_partial_rope_kernel(
|
||||
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
const float* __restrict__ cos_emb, // [2, 1, max_model_len, 1, rotary_dim/2]
|
||||
const float* __restrict__ sin_emb, // [2, 1, max_model_len, 1, rotary_dim/2]
|
||||
const float* __restrict__ cos_emb, // [2, 1, max_model_len, 1,
|
||||
// rotary_dim/2]
|
||||
const float* __restrict__ sin_emb, // [2, 1, max_model_len, 1,
|
||||
// rotary_dim/2]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
@@ -479,7 +482,7 @@ __global__ void append_decode_cache_T_neox_partial_rope_kernel(
|
||||
const int bias = linear_index % half_hidden_size;
|
||||
const int hi = bias / half_head_size; // q + k + v
|
||||
const int h_bias = bias % half_head_size;
|
||||
if (hi < num_heads && h_bias >= half_rotary_dim){
|
||||
if (hi < num_heads && h_bias >= half_rotary_dim) {
|
||||
continue;
|
||||
}
|
||||
const int start_token_idx = cu_seqlens_q[ori_bi];
|
||||
@@ -495,12 +498,12 @@ __global__ void append_decode_cache_T_neox_partial_rope_kernel(
|
||||
uint32_t ori_idx_left =
|
||||
start_token_idx * hidden_size + hi * head_size + h_bias;
|
||||
uint32_t ori_idx_right = ori_idx_left + half_head_size;
|
||||
if (hi < num_heads){
|
||||
if (hi < num_heads) {
|
||||
ori_idx_right = ori_idx_left + half_rotary_dim;
|
||||
}else if (hi < num_heads + kv_num_heads){
|
||||
if (h_bias < half_rotary_dim){
|
||||
} else if (hi < num_heads + kv_num_heads) {
|
||||
if (h_bias < half_rotary_dim) {
|
||||
ori_idx_right = ori_idx_left + half_rotary_dim;
|
||||
}else{
|
||||
} else {
|
||||
ori_idx_left = ori_idx_left + half_rotary_dim;
|
||||
ori_idx_right = ori_idx_left + half_rotary_dim;
|
||||
}
|
||||
@@ -512,8 +515,9 @@ __global__ void append_decode_cache_T_neox_partial_rope_kernel(
|
||||
if (hi < num_heads + kv_num_heads) {
|
||||
// q k rope
|
||||
const uint32_t emb_idx = write_seq_id * half_rotary_dim + h_bias;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
|
||||
if (h_bias < half_rotary_dim){
|
||||
uint32_t new_emb_idx =
|
||||
rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
|
||||
if (h_bias < half_rotary_dim) {
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
@@ -550,7 +554,7 @@ __global__ void append_decode_cache_T_neox_partial_rope_kernel(
|
||||
if (hi < num_heads + kv_num_heads) {
|
||||
if (h_bias < half_rotary_dim) {
|
||||
tgt_idx_right = tgt_idx_left + half_rotary_dim;
|
||||
}else{
|
||||
} else {
|
||||
tgt_idx_left = tgt_idx_left + half_rotary_dim;
|
||||
tgt_idx_right = tgt_idx_left + half_rotary_dim;
|
||||
}
|
||||
@@ -573,7 +577,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -633,7 +637,8 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
if (hi < num_heads + kv_num_heads) {
|
||||
// q k rope
|
||||
const uint32_t emb_idx = write_seq_id * head_size + h_bias;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
|
||||
uint32_t new_emb_idx =
|
||||
rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
@@ -686,7 +691,7 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
|
||||
// head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -694,8 +699,8 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
const float* __restrict__ sin_emb,
|
||||
const float* __restrict__ qkv_out_scales, // [num_head + 2 *
|
||||
// kv_num_heads, dim_head]
|
||||
const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads,
|
||||
// dim_head]
|
||||
const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads,
|
||||
// dim_head]
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
@@ -760,7 +765,8 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
if (hi < num_heads + kv_num_heads) {
|
||||
// q k rope
|
||||
const uint32_t emb_idx = write_seq_id * head_size + h_bias;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
|
||||
uint32_t new_emb_idx =
|
||||
rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
}
|
||||
@@ -810,7 +816,13 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128, bool is_scale_channel_wise=false, bool IsFP8=true>
|
||||
template <typename T,
|
||||
int VecSize = 4,
|
||||
int RoundType = 0,
|
||||
int HeadDim = 128,
|
||||
bool is_scale_channel_wise = false,
|
||||
bool IsFP8 = true,
|
||||
bool IsDynamic = true>
|
||||
__global__ void append_decode_cache_int8_rope_qk_norm_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
@@ -819,7 +831,7 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
|
||||
uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads,
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -858,15 +870,6 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
|
||||
const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
|
||||
const int block_offset = write_seq_id % block_size;
|
||||
|
||||
int cache_offset;
|
||||
if (head_idx < num_heads) {
|
||||
cache_offset = 0;
|
||||
} else if (head_idx < num_heads + 2 * kv_num_heads) {
|
||||
cache_offset = block_idx * kv_num_heads * block_size + (head_idx - num_heads) % kv_num_heads * block_size + block_offset;
|
||||
}
|
||||
T *cache_k_scale_now = cache_k_scale + cache_offset;
|
||||
T *cache_v_scale_now = cache_v_scale + cache_offset;
|
||||
|
||||
float thread_m2 = 0.0f;
|
||||
float warp_m2 = 0.0f;
|
||||
|
||||
@@ -891,7 +894,8 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
|
||||
Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
const uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
const uint32_t new_emb_idx =
|
||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
@@ -905,22 +909,20 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
out_vec[2 * i] =
|
||||
static_cast<T>(tmp1);
|
||||
out_vec[2 * i + 1] =
|
||||
static_cast<T>(tmp2);
|
||||
out_vec[2 * i] = static_cast<T>(tmp1);
|
||||
out_vec[2 * i + 1] = static_cast<T>(tmp2);
|
||||
}
|
||||
// qk norm
|
||||
if (q_norm_weight) {
|
||||
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
||||
float row_variance =
|
||||
max(warp_m2 / HeadDim, 0.0f);
|
||||
float row_variance = max(warp_m2 / HeadDim, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
LoadOutScaleT q_norm_vec;
|
||||
Load<float, VecSize>(&q_norm_weight[lane_id * VecSize], &q_norm_vec);
|
||||
#pragma unroll
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * q_norm_vec[i]);
|
||||
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) *
|
||||
row_inv_var * q_norm_vec[i]);
|
||||
}
|
||||
}
|
||||
Store<T, VecSize>(out_vec, &qkv_out_now[bias_idx]);
|
||||
@@ -985,7 +987,8 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
|
||||
const int v_head_idx = head_idx - num_heads - kv_num_heads;
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
const uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
const uint32_t new_emb_idx =
|
||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
@@ -1000,10 +1003,8 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
|
||||
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
|
||||
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
|
||||
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
|
||||
out_vec1[0] =
|
||||
static_cast<T>(tmp1);
|
||||
out_vec1[1] =
|
||||
static_cast<T>(tmp2);
|
||||
out_vec1[0] = static_cast<T>(tmp1);
|
||||
out_vec1[1] = static_cast<T>(tmp2);
|
||||
} else {
|
||||
out_vec1[0] = src_vec1[0];
|
||||
out_vec1[1] = src_vec1[1];
|
||||
@@ -1028,45 +1029,67 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
LoadOutScaleT k_norm_vec1, k_norm_vec2;
|
||||
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias], &k_norm_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias + 8], &k_norm_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias + 8],
|
||||
&k_norm_vec2);
|
||||
// qk norm
|
||||
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
|
||||
float row_variance =
|
||||
max(warp_m2 / HeadDim, 0.0f);
|
||||
float row_variance = max(warp_m2 / HeadDim, 0.0f);
|
||||
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
|
||||
|
||||
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
out_vec1[i] = static_cast<T>(static_cast<float>(out_vec1[i]) * row_inv_var * k_norm_vec1[i]);
|
||||
out_vec2[i] = static_cast<T>(static_cast<float>(out_vec2[i]) * row_inv_var * k_norm_vec2[i]);
|
||||
out_vec1[i] = static_cast<T>(static_cast<float>(out_vec1[i]) *
|
||||
row_inv_var * k_norm_vec1[i]);
|
||||
out_vec2[i] = static_cast<T>(static_cast<float>(out_vec2[i]) *
|
||||
row_inv_var * k_norm_vec2[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
// reduce max, 1 head per warp
|
||||
T local_max = -INFINITY;
|
||||
if constexpr (IsDynamic) {
|
||||
// reduce max, 1 head per warp
|
||||
T local_max = -INFINITY;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
local_max = __hmax(local_max, __habs(out_vec1[i]));
|
||||
local_max = __hmax(local_max, __habs(out_vec2[i]));
|
||||
}
|
||||
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
local_max = __hmax(local_max, __habs(out_vec1[i]));
|
||||
local_max = __hmax(local_max, __habs(out_vec2[i]));
|
||||
}
|
||||
#pragma unroll
|
||||
for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
|
||||
local_max = __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
|
||||
}
|
||||
for (int m_offset = 16; m_offset > 0; m_offset /= 2) {
|
||||
local_max =
|
||||
__hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
|
||||
}
|
||||
scale = __hdiv(448, local_max);
|
||||
|
||||
scale = __hdiv(448, local_max);
|
||||
|
||||
if (lane_id == 0) {
|
||||
int cache_offset;
|
||||
if (head_idx < num_heads) {
|
||||
cache_offset = 0;
|
||||
} else if (head_idx < num_heads + 2 * kv_num_heads) {
|
||||
cache_offset = block_idx * kv_num_heads * block_size +
|
||||
(head_idx - num_heads) % kv_num_heads * block_size +
|
||||
block_offset;
|
||||
}
|
||||
T* cache_k_scale_now = cache_k_scale + cache_offset;
|
||||
T* cache_v_scale_now = cache_v_scale + cache_offset;
|
||||
if (lane_id == 0) {
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
cache_k_scale_now[0] = __hdiv(1, scale);
|
||||
} else {
|
||||
cache_v_scale_now[0] = __hdiv(1, scale);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
cache_k_scale_now[0] = __hdiv(1, scale);
|
||||
scale = __ldg(&cache_k_scale[kv_head_idx]);
|
||||
} else {
|
||||
cache_v_scale_now[0] = __hdiv(1, scale);
|
||||
scale = __ldg(&cache_v_scale[kv_head_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
cache_vec[i] = QuantToC8<T,true, IsFP8, RoundType>(scale, out_vec1[i], max_bound, min_bound);
|
||||
cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T,true, IsFP8, RoundType>(scale, out_vec2[i], max_bound, min_bound);
|
||||
cache_vec[i] = QuantToC8<T, true, IsFP8, RoundType>(
|
||||
scale, out_vec1[i], max_bound, min_bound);
|
||||
cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T, true, IsFP8, RoundType>(
|
||||
scale, out_vec2[i], max_bound, min_bound);
|
||||
}
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
const int start_block_16 =
|
||||
@@ -1097,7 +1120,12 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128, bool is_scale_channel_wise=false, bool IsFP8=false>
|
||||
template <typename T,
|
||||
int VecSize = 4,
|
||||
int RoundType = 0,
|
||||
int HeadDim = 128,
|
||||
bool is_scale_channel_wise = false,
|
||||
bool IsFP8 = false>
|
||||
__global__ void append_decode_cache_int8_rope_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
@@ -1106,7 +1134,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads,
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -1144,17 +1172,18 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
|
||||
if (head_idx < num_heads) {
|
||||
// q
|
||||
const T* qkv_now = quant_qkv + start_token_idx * hidden_size + head_idx * HeadDim;
|
||||
T* qkv_out_now = qkv_out + start_token_idx * hidden_size + head_idx * HeadDim;
|
||||
const T* qkv_now =
|
||||
quant_qkv + start_token_idx * hidden_size + head_idx * HeadDim;
|
||||
T* qkv_out_now =
|
||||
qkv_out + start_token_idx * hidden_size + head_idx * HeadDim;
|
||||
|
||||
uint32_t emb_offset = write_seq_id * half_head_size;
|
||||
emb_offset += rope_3d ? bid * max_seq_len * HeadDim : 0;
|
||||
apply_rope<T, VecSize, HeadDim, 32>(
|
||||
qkv_now,
|
||||
cos_emb + emb_offset,
|
||||
sin_emb + emb_offset,
|
||||
qkv_out_now,
|
||||
lane_id);
|
||||
apply_rope<T, VecSize, HeadDim, 32>(qkv_now,
|
||||
cos_emb + emb_offset,
|
||||
sin_emb + emb_offset,
|
||||
qkv_out_now,
|
||||
lane_id);
|
||||
|
||||
} else if (head_idx < num_heads + 2 * kv_num_heads) {
|
||||
// k
|
||||
@@ -1213,11 +1242,14 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
T scale = T(1.0f);
|
||||
const int k_head_idx = head_idx - num_heads;
|
||||
const int v_head_idx = head_idx - num_heads - kv_num_heads;
|
||||
const T *cache_k_scale_cur = cache_k_scale + k_head_idx * HeadDim + head_bias;
|
||||
const T *cache_v_scale_cur = cache_v_scale + v_head_idx * HeadDim + head_bias;
|
||||
const T* cache_k_scale_cur =
|
||||
cache_k_scale + k_head_idx * HeadDim + head_bias;
|
||||
const T* cache_v_scale_cur =
|
||||
cache_v_scale + v_head_idx * HeadDim + head_bias;
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
uint32_t new_emb_idx =
|
||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
@@ -1250,9 +1282,11 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec1[0];
|
||||
float sin_tmp = sin_emb_vec1[0];
|
||||
out_vec1[0] =
|
||||
static_cast<T>((input_left * cos_tmp - input_right * sin_tmp) * float(cache_k_scale_cur[0]));
|
||||
static_cast<T>((input_left * cos_tmp - input_right * sin_tmp) *
|
||||
float(cache_k_scale_cur[0]));
|
||||
out_vec1[1] =
|
||||
static_cast<T>((input_right * cos_tmp + input_left * sin_tmp) * float(cache_k_scale_cur[1]));
|
||||
static_cast<T>((input_right * cos_tmp + input_left * sin_tmp) *
|
||||
float(cache_k_scale_cur[1]));
|
||||
} else {
|
||||
out_vec1[0] = static_cast<T>(input_left * float(cache_v_scale_cur[0]));
|
||||
out_vec1[1] = static_cast<T>(input_right * float(cache_v_scale_cur[1]));
|
||||
@@ -1278,9 +1312,11 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec2[0];
|
||||
float sin_tmp = sin_emb_vec2[0];
|
||||
out_vec2[0] =
|
||||
static_cast<T>((input_left * cos_tmp - input_right * sin_tmp) * float(cache_k_scale_cur[8]));
|
||||
static_cast<T>((input_left * cos_tmp - input_right * sin_tmp) *
|
||||
float(cache_k_scale_cur[8]));
|
||||
out_vec2[1] =
|
||||
static_cast<T>((input_right * cos_tmp + input_left * sin_tmp) * float(cache_k_scale_cur[9]));
|
||||
static_cast<T>((input_right * cos_tmp + input_left * sin_tmp) *
|
||||
float(cache_k_scale_cur[9]));
|
||||
} else {
|
||||
out_vec2[0] = static_cast<T>(input_left * float(cache_v_scale_cur[8]));
|
||||
out_vec2[1] = static_cast<T>(input_right * float(cache_v_scale_cur[9]));
|
||||
@@ -1288,8 +1324,10 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
}
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) {
|
||||
cache_vec[i] = QuantToC8<T,true, IsFP8, RoundType>(scale, out_vec1[i], max_bound, min_bound);
|
||||
cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T,true, IsFP8, RoundType>(scale, out_vec2[i], max_bound, min_bound);
|
||||
cache_vec[i] = QuantToC8<T, true, IsFP8, RoundType>(
|
||||
scale, out_vec1[i], max_bound, min_bound);
|
||||
cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T, true, IsFP8, RoundType>(
|
||||
scale, out_vec2[i], max_bound, min_bound);
|
||||
}
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
const int start_block_16 =
|
||||
@@ -1320,7 +1358,12 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128, bool is_scale_channel_wise=false, bool IsFP8=false>
|
||||
template <typename T,
|
||||
int VecSize = 4,
|
||||
int RoundType = 0,
|
||||
int HeadDim = 128,
|
||||
bool is_scale_channel_wise = false,
|
||||
bool IsFP8 = false>
|
||||
__global__ void append_decode_cache_int8_rope_kernel(
|
||||
const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
// head_size]
|
||||
@@ -1329,7 +1372,7 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads,
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -1337,8 +1380,8 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
const float* __restrict__ sin_emb,
|
||||
const float* __restrict__ qkv_out_scales, // [num_head + 2 *
|
||||
// kv_num_heads, dim_head]
|
||||
const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads,
|
||||
// dim_head]
|
||||
const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads,
|
||||
// dim_head]
|
||||
const T* __restrict__ cache_k_scales,
|
||||
const T* __restrict__ cache_v_scales,
|
||||
const int max_seq_len,
|
||||
@@ -1398,7 +1441,8 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
uint32_t new_emb_idx =
|
||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
@@ -1490,11 +1534,14 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
T scale = T(1.0f);
|
||||
const int k_head_idx = head_idx - num_heads;
|
||||
const int v_head_idx = head_idx - num_heads - kv_num_heads;
|
||||
const T *cache_k_scale_cur = cache_k_scales + k_head_idx * HeadDim + head_bias;
|
||||
const T *cache_v_scale_cur = cache_v_scales + v_head_idx * HeadDim + head_bias;
|
||||
const T* cache_k_scale_cur =
|
||||
cache_k_scales + k_head_idx * HeadDim + head_bias;
|
||||
const T* cache_v_scale_cur =
|
||||
cache_v_scales + v_head_idx * HeadDim + head_bias;
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
uint32_t new_emb_idx =
|
||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
@@ -1533,12 +1580,15 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
float cos_tmp = cos_emb_vec1[0];
|
||||
float sin_tmp = sin_emb_vec1[0];
|
||||
bias_vec1[0] =
|
||||
static_cast<T>((input_left * cos_tmp - input_right * sin_tmp) * float(cache_k_scale_cur[0]));
|
||||
static_cast<T>((input_left * cos_tmp - input_right * sin_tmp) *
|
||||
float(cache_k_scale_cur[0]));
|
||||
bias_vec1[1] =
|
||||
static_cast<T>((input_right * cos_tmp + input_left * sin_tmp) * float(cache_k_scale_cur[1]));
|
||||
static_cast<T>((input_right * cos_tmp + input_left * sin_tmp) *
|
||||
float(cache_k_scale_cur[1]));
|
||||
} else {
|
||||
bias_vec1[0] = static_cast<T>(input_left * float(cache_v_scale_cur[0]));
|
||||
bias_vec1[1] = static_cast<T>(input_right * float(cache_v_scale_cur[1]));
|
||||
bias_vec1[1] =
|
||||
static_cast<T>(input_right * float(cache_v_scale_cur[1]));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1563,16 +1613,19 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
bias_vec2[1] = static_cast<T>(input_right);
|
||||
}
|
||||
} else {
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
float cos_tmp = cos_emb_vec2[0];
|
||||
float sin_tmp = sin_emb_vec2[0];
|
||||
bias_vec2[0] =
|
||||
static_cast<T>((input_left * cos_tmp - input_right * sin_tmp) * float(cache_k_scale_cur[8]));
|
||||
static_cast<T>((input_left * cos_tmp - input_right * sin_tmp) *
|
||||
float(cache_k_scale_cur[8]));
|
||||
bias_vec2[1] =
|
||||
static_cast<T>((input_right * cos_tmp + input_left * sin_tmp) * float(cache_k_scale_cur[9]));
|
||||
static_cast<T>((input_right * cos_tmp + input_left * sin_tmp) *
|
||||
float(cache_k_scale_cur[9]));
|
||||
} else {
|
||||
bias_vec2[0] = static_cast<T>(input_left * float(cache_v_scale_cur[8]));
|
||||
bias_vec2[1] = static_cast<T>(input_right * float(cache_v_scale_cur[9]));
|
||||
bias_vec2[1] =
|
||||
static_cast<T>(input_right * float(cache_v_scale_cur[9]));
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
@@ -1623,7 +1676,6 @@ __global__ void append_decode_cache_int8_rope_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128>
|
||||
__global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
@@ -1633,7 +1685,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads,
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
const int* __restrict__ seq_lens_encoder, // [bsz]
|
||||
@@ -1696,7 +1748,8 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
uint32_t new_emb_idx =
|
||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
|
||||
@@ -1779,7 +1832,8 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
|
||||
T scale;
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
uint32_t new_emb_idx =
|
||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
@@ -1934,7 +1988,7 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads,
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
@@ -1943,8 +1997,8 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
const float* __restrict__ sin_emb,
|
||||
const float* __restrict__ qkv_out_scales, // [num_head + 2 *
|
||||
// kv_num_heads, dim_head]
|
||||
const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads,
|
||||
// dim_head]
|
||||
const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads,
|
||||
// dim_head]
|
||||
const T* __restrict__ cache_k_scales,
|
||||
const T* __restrict__ cache_v_scales,
|
||||
const int max_seq_len,
|
||||
@@ -2014,11 +2068,11 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
uint32_t new_emb_idx =
|
||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VecSize; i++) {
|
||||
// dequant + add_bias + rope
|
||||
@@ -2125,7 +2179,8 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
|
||||
T scale;
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
uint32_t new_emb_idx =
|
||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
@@ -2321,7 +2376,6 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128>
|
||||
__global__ void append_decode_cache_int4_rope_kernel(
|
||||
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
|
||||
@@ -2331,7 +2385,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads,
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
@@ -2373,17 +2427,18 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
|
||||
if (head_idx < num_heads) {
|
||||
// q
|
||||
const T* qkv_now = quant_qkv + start_token_idx * hidden_size + head_idx * HeadDim;
|
||||
T* qkv_out_now = qkv_out + start_token_idx * hidden_size + head_idx * HeadDim;
|
||||
const T* qkv_now =
|
||||
quant_qkv + start_token_idx * hidden_size + head_idx * HeadDim;
|
||||
T* qkv_out_now =
|
||||
qkv_out + start_token_idx * hidden_size + head_idx * HeadDim;
|
||||
|
||||
uint32_t emb_offset = write_seq_id * half_head_size;
|
||||
emb_offset += rope_3d ? bid * max_seq_len * HeadDim : 0;
|
||||
apply_rope<T, VecSize, HeadDim, 32>(
|
||||
qkv_now,
|
||||
cos_emb + emb_offset,
|
||||
sin_emb + emb_offset,
|
||||
qkv_out_now,
|
||||
lane_id);
|
||||
apply_rope<T, VecSize, HeadDim, 32>(qkv_now,
|
||||
cos_emb + emb_offset,
|
||||
sin_emb + emb_offset,
|
||||
qkv_out_now,
|
||||
lane_id);
|
||||
|
||||
} else if (head_idx < num_heads + 2 * kv_num_heads) {
|
||||
// k
|
||||
@@ -2443,7 +2498,8 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
uint32_t new_emb_idx =
|
||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
@@ -2603,7 +2659,7 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads,
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
@@ -2612,8 +2668,8 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
const float* __restrict__ sin_emb,
|
||||
const float* __restrict__ qkv_out_scales, // [num_head + 2 *
|
||||
// kv_num_heads, dim_head]
|
||||
const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads,
|
||||
// dim_head]
|
||||
const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads,
|
||||
// dim_head]
|
||||
const T* __restrict__ cache_k_scale,
|
||||
const T* __restrict__ cache_v_scale,
|
||||
const T* __restrict__ cache_k_zero_points,
|
||||
@@ -2674,7 +2730,8 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
Load<float, VecSize>(&qkv_out_scales[bias_idx], &out_scale_vec);
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
uint32_t new_emb_idx =
|
||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
@@ -2763,7 +2820,8 @@ __global__ void append_decode_cache_int4_rope_kernel(
|
||||
&out_scale_vec2);
|
||||
if (head_idx < num_heads + kv_num_heads) {
|
||||
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
uint32_t new_emb_idx =
|
||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
|
||||
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
|
||||
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
@@ -2934,7 +2992,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads,
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
@@ -2999,7 +3057,8 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
uint32_t new_emb_idx =
|
||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
@@ -3082,7 +3141,8 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
Load<T, HALF_K_VEC_SIZE>(&qkv_now[right_bias_idx], &right_src_vec1);
|
||||
Load<T, HALF_K_VEC_SIZE>(&qkv_now[right_bias_idx + 8], &right_src_vec2);
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
uint32_t new_emb_idx =
|
||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
@@ -3114,7 +3174,6 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
right_out_vec1[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
|
||||
|
||||
input_left = static_cast<float>(left_src_vec2[i]);
|
||||
input_right = static_cast<float>(right_src_vec2[i]);
|
||||
cos_tmp = cos_emb_vec2[i];
|
||||
@@ -3307,7 +3366,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads,
|
||||
// block_size, head_size // 2]
|
||||
T* __restrict__ qkv_out,
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
|
||||
|
||||
const int* __restrict__ cu_seqlens_q,
|
||||
const int* __restrict__ seq_lens, // [bsz]
|
||||
@@ -3316,8 +3375,8 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
const float* __restrict__ sin_emb,
|
||||
const float* __restrict__ qkv_out_scales, // [num_head + 2 *
|
||||
// kv_num_heads, dim_head]
|
||||
const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads,
|
||||
// dim_head]
|
||||
const T* __restrict__ qkv_biases, // [num_head + 2 * kv_num_heads,
|
||||
// dim_head]
|
||||
const T* __restrict__ cache_k_scale,
|
||||
const T* __restrict__ cache_v_scale,
|
||||
const T* __restrict__ cache_k_zero_points,
|
||||
@@ -3387,7 +3446,8 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
&right_out_scale_vec);
|
||||
// q rope
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
uint32_t new_emb_idx =
|
||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
|
||||
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
|
||||
#pragma unroll
|
||||
@@ -3498,7 +3558,8 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
&right_out_scale_vec2);
|
||||
|
||||
const uint32_t emb_idx = write_seq_id * HeadDim + head_bias;
|
||||
uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
uint32_t new_emb_idx =
|
||||
rope_3d ? emb_idx + bid * max_seq_len * HeadDim * 2 : emb_idx;
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx], &cos_emb_vec1);
|
||||
Load<float, HALF_K_VEC_SIZE>(&cos_emb[new_emb_idx + 8], &cos_emb_vec2);
|
||||
Load<float, HALF_K_VEC_SIZE>(&sin_emb[new_emb_idx], &sin_emb_vec1);
|
||||
@@ -3536,7 +3597,6 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
|
||||
right_bias_vec1[i] =
|
||||
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
|
||||
|
||||
|
||||
input_left = static_cast<float>(left_src_vec2[i]);
|
||||
input_right = static_cast<float>(right_src_vec2[i]);
|
||||
cos_tmp = cos_emb_vec2[i];
|
||||
|
||||
@@ -17,30 +17,30 @@
|
||||
|
||||
template <typename T, typename QKV_TYPE>
|
||||
void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
|
||||
T* key_cache,
|
||||
T* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
const float* sin_emb,
|
||||
const float* qkv_out_scales,
|
||||
const T* qkv_biases,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int kv_num_heads,
|
||||
const int dim_head,
|
||||
const int block_size,
|
||||
const int bsz,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style,
|
||||
const bool rope_3d,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const float rms_norm_eps) {
|
||||
T* key_cache,
|
||||
T* value_cache,
|
||||
T* qkv_out,
|
||||
const int* block_tables,
|
||||
const int* cu_seqlens_q,
|
||||
const int* seq_lens,
|
||||
const int* seq_lens_encoder,
|
||||
const float* cos_emb,
|
||||
const float* sin_emb,
|
||||
const float* qkv_out_scales,
|
||||
const T* qkv_biases,
|
||||
const int max_seq_len,
|
||||
const int max_blocks_per_seq,
|
||||
const int num_heads,
|
||||
const int kv_num_heads,
|
||||
const int dim_head,
|
||||
const int block_size,
|
||||
const int bsz,
|
||||
const cudaStream_t& stream,
|
||||
const bool use_neox_style,
|
||||
const bool rope_3d,
|
||||
const float* q_norm_weight,
|
||||
const float* k_norm_weight,
|
||||
const float rms_norm_eps) {
|
||||
const uint32_t elem_nums =
|
||||
use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
|
||||
: bsz * (num_heads + 2 * kv_num_heads) * dim_head;
|
||||
@@ -134,47 +134,49 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
if (rotary_dim < dim_head){
|
||||
if (rotary_dim < dim_head) {
|
||||
append_decode_cache_T_neox_partial_rope_kernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
dim_head,
|
||||
rotary_dim,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}else{
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
dim_head,
|
||||
rotary_dim,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_decode_cache_T_neox_rope_kernel<T, PackSize>
|
||||
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
<<<grid_size, blocksize, 0, stream>>>(
|
||||
reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
value_cache,
|
||||
qkv_out,
|
||||
block_tables,
|
||||
cu_seqlens_q,
|
||||
seq_lens,
|
||||
seq_lens_encoder,
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
elem_nums,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -225,7 +227,10 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename QKV_TYPE, bool is_scale_channel_wise = false, bool IsFP8=false>
|
||||
template <typename T,
|
||||
typename QKV_TYPE,
|
||||
bool is_scale_channel_wise = false,
|
||||
bool IsFP8 = false>
|
||||
void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
uint8_t* key_cache,
|
||||
uint8_t* value_cache,
|
||||
@@ -306,7 +311,12 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
}
|
||||
} else {
|
||||
if (qkv_out_scales) {
|
||||
append_decode_cache_int8_rope_kernel<T, 4, 0, 128, is_scale_channel_wise, IsFP8>
|
||||
append_decode_cache_int8_rope_kernel<T,
|
||||
4,
|
||||
0,
|
||||
128,
|
||||
is_scale_channel_wise,
|
||||
IsFP8>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
reinterpret_cast<const int*>(qkv),
|
||||
key_cache,
|
||||
@@ -331,7 +341,12 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
|
||||
kv_num_heads,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_decode_cache_int8_rope_kernel<T, 4, 0, 128, is_scale_channel_wise, IsFP8>
|
||||
append_decode_cache_int8_rope_kernel<T,
|
||||
4,
|
||||
0,
|
||||
128,
|
||||
is_scale_channel_wise,
|
||||
IsFP8>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
reinterpret_cast<const T*>(qkv),
|
||||
key_cache,
|
||||
@@ -546,11 +561,15 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
use_neox_rotary_style
|
||||
? rotary_embs.get().data<float>() + max_seq_len * dim_head
|
||||
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
|
||||
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
|
||||
if(rotary_dim < dim_head){
|
||||
if (!use_neox_rotary_style || qkv_out_scales || q_norm_weight || k_norm_weight|| cache_quant_type_str != "none"){
|
||||
rotary_dim =
|
||||
rotary_embs.get().dims()[rotary_embs.get().dims().size() - 1] * 2;
|
||||
if (rotary_dim < dim_head) {
|
||||
if (!use_neox_rotary_style || qkv_out_scales || q_norm_weight ||
|
||||
k_norm_weight || cache_quant_type_str != "none") {
|
||||
PADDLE_THROW(phi::errors::Fatal(
|
||||
"partial_rotary_factor < 1.0 only supports neox_rotary_style=True, qkv_out_scales is None, q_norm_weight/k_norm_weight) is None, and cache_quant_type_str is 'none'."));
|
||||
"partial_rotary_factor < 1.0 only supports neox_rotary_style=True, "
|
||||
"qkv_out_scales is None, q_norm_weight/k_norm_weight) is None, and "
|
||||
"cache_quant_type_str is 'none'."));
|
||||
}
|
||||
sin_emb = rotary_embs.get().data<float>() + max_seq_len * rotary_dim / 2;
|
||||
}
|
||||
@@ -571,8 +590,8 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
@@ -588,10 +607,16 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
rms_norm_eps);
|
||||
} else if (cache_quant_type_str == "block_wise_fp8") {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
const int all_warps = ((num_heads + 2 * kv_num_heads) + num_warps - 1) /
|
||||
num_warps * num_warps;
|
||||
dim3 grids(bsz, all_warps / num_warps);
|
||||
append_decode_cache_int8_rope_qk_norm_kernel<DataType_, 4, 0, 128, false, true>
|
||||
append_decode_cache_int8_rope_qk_norm_kernel<DataType_,
|
||||
4,
|
||||
0,
|
||||
128,
|
||||
false,
|
||||
true,
|
||||
true>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
@@ -603,8 +628,48 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>((cache_v_scale.get().data<T>()))),
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
|
||||
cache_k_scale.get().data<T>())),
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
|
||||
(cache_v_scale.get().data<T>()))),
|
||||
q_norm_weight.get().data<float>(),
|
||||
k_norm_weight.get().data<float>(),
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
block_size,
|
||||
127.0f,
|
||||
-127.0f,
|
||||
kv_num_heads,
|
||||
rope_3d,
|
||||
rms_norm_eps);
|
||||
} else if ((cache_quant_type_str == "cache_fp8")) {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps = ((num_heads + 2 * kv_num_heads) + num_warps - 1) /
|
||||
num_warps * num_warps;
|
||||
dim3 grids(bsz, all_warps / num_warps);
|
||||
append_decode_cache_int8_rope_qk_norm_kernel<DataType_,
|
||||
4,
|
||||
0,
|
||||
128,
|
||||
false,
|
||||
true,
|
||||
false>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
|
||||
cache_k_scale.get().data<T>())),
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
|
||||
(cache_v_scale.get().data<T>()))),
|
||||
q_norm_weight.get().data<float>(),
|
||||
k_norm_weight.get().data<float>(),
|
||||
max_seq_len,
|
||||
@@ -618,7 +683,8 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
rms_norm_eps);
|
||||
} else {
|
||||
PD_THROW(
|
||||
"append_decode_cache_rope_qk_norm just supports cache_quant_type none/block_wise_fp8");
|
||||
"append_decode_cache_rope_qk_norm just supports cache_quant_type "
|
||||
"none/block_wise_fp8/cache_fp8");
|
||||
}
|
||||
} else {
|
||||
if (cache_quant_type_str == "none") {
|
||||
@@ -635,8 +701,8 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
@@ -650,11 +716,77 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "cache_int8") {
|
||||
bool is_scale_channel_wise = false;
|
||||
if (cache_k_scale && cache_k_scale.get().dims()[0] == dim_head * kv_num_heads) {
|
||||
if (cache_k_scale &&
|
||||
cache_k_scale.get().dims()[0] == dim_head * kv_num_heads) {
|
||||
is_scale_channel_wise = true;
|
||||
}
|
||||
if (is_scale_channel_wise) {
|
||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, true>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
}
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false, true>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
@@ -667,8 +799,8 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -685,77 +817,17 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else {
|
||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
}
|
||||
} else if (cache_quant_type_str == "cache_fp8") {
|
||||
append_decode_cache_int8_rope<DataType_, QKV_TYPE, false, true>(
|
||||
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
value_cache_out->data<uint8_t>(),
|
||||
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
|
||||
block_tables.data<int>(),
|
||||
cu_seqlens_q.data<int>(),
|
||||
seq_lens.data<int>(),
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
kv_num_heads,
|
||||
dim_head,
|
||||
block_size,
|
||||
bsz,
|
||||
stream,
|
||||
use_neox_rotary_style,
|
||||
rope_3d);
|
||||
} else if (cache_quant_type_str == "block_wise_fp8") {
|
||||
constexpr int num_warps = 4;
|
||||
const int all_warps =
|
||||
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
|
||||
const int all_warps = ((num_heads + 2 * kv_num_heads) + num_warps - 1) /
|
||||
num_warps * num_warps;
|
||||
dim3 grids(bsz, all_warps / num_warps);
|
||||
append_decode_cache_int8_rope_qk_norm_kernel<DataType_, 4, 0, 128, false, true>
|
||||
append_decode_cache_int8_rope_qk_norm_kernel<DataType_,
|
||||
4,
|
||||
0,
|
||||
128,
|
||||
false,
|
||||
true>
|
||||
<<<grids, num_warps * 32, 0, stream>>>(
|
||||
reinterpret_cast<const DataType_*>(qkv_ptr),
|
||||
key_cache_out->data<uint8_t>(),
|
||||
@@ -767,8 +839,10 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
seq_lens_encoder.data<int>(),
|
||||
cos_emb,
|
||||
sin_emb,
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>((cache_v_scale.get().data<T>()))),
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
|
||||
cache_k_scale.get().data<T>())),
|
||||
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
|
||||
(cache_v_scale.get().data<T>()))),
|
||||
nullptr,
|
||||
nullptr,
|
||||
max_seq_len,
|
||||
@@ -794,8 +868,8 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
sin_emb,
|
||||
qkv_out_scales ? qkv_out_scales.get().data<float>() : nullptr,
|
||||
qkv_biases ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T*>(qkv_biases.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_scale ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
@@ -803,11 +877,11 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
const_cast<T*>(cache_v_scale.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_k_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_k_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T*>(cache_k_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
cache_v_zp ? reinterpret_cast<DataType_*>(
|
||||
const_cast<T*>(cache_v_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
const_cast<T*>(cache_v_zp.get().data<T>()))
|
||||
: nullptr,
|
||||
max_seq_len,
|
||||
max_blocks_per_seq,
|
||||
num_heads,
|
||||
@@ -826,7 +900,6 @@ void DecoderWriteCacheWithRoPEKernel(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template void DecoderWriteCacheWithRoPEKernel<paddle::bfloat16, int>(
|
||||
const AppendAttnMetaData& meta_data,
|
||||
const paddle::Tensor&
|
||||
|
||||
Reference in New Issue
Block a user