mirror of https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-10-05 00:33:03 +08:00
supports dynamic Cfp8 (#3767)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
* supports dynamic Cfp8
* add unittest
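The commit introduces a "block_wise_fp8" KV-cache quantization mode ("dynamic Cfp8"): instead of a single static per-head scale, every cached token row gets its own dequantization scale, computed on the fly from the row's absolute maximum (FP8 E4M3 saturates at 448, hence the 448 / max|x| factors in the kernels below, with the reciprocal stored for dequantization). A minimal host-side sketch of that round trip, using hypothetical helper names that are not part of the diff:

#include <algorithm>
#include <cmath>

// Reference of the per-token dynamic FP8 scheme (illustrative only):
//   quant:   q = clamp(x * (448 / absmax_row))   -> stored as 8 bits in the cache
//   dequant: x ~= q * stored_scale, where stored_scale = absmax_row / 448
void quant_row_dynamic_fp8(const float* row, int head_dim,
                           float* q_out /* stand-in for the fp8 cache */,
                           float* stored_scale /* what cache_k/v_scale holds */) {
  float absmax = 1e-6f;
  for (int i = 0; i < head_dim; ++i) absmax = std::max(absmax, std::fabs(row[i]));
  const float scale = 448.0f / absmax;  // what the kernels call `scale`
  *stored_scale = 1.0f / scale;         // written once per token row
  for (int i = 0; i < head_dim; ++i)
    q_out[i] = std::clamp(row[i] * scale, -448.0f, 448.0f);  // fp8 cast in the real kernel
}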
@@ -140,8 +140,8 @@ void AppendAttentionKernel(
key_cache,
value_cache,
attn_mask,
cache_k_dequant_scales,
cache_v_dequant_scales,
cache_quant_type_str == "block_wise_fp8" ? cache_k_quant_scales : cache_k_dequant_scales,
cache_quant_type_str == "block_wise_fp8" ? cache_v_quant_scales : cache_v_dequant_scales,
cache_k_zp,
cache_v_zp,
out_linear_shifts,
@@ -32,14 +32,15 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true,
bool is_scale_channel_wise = false,
bool IsFP8=false>
bool IsFP8 = false,
bool IsDynamicC8 = false>
__global__ void multi_query_append_attention_c8_kernel(
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
CacheT *__restrict__ cache_v,
const T *__restrict__ cache_k_scale, // [num_kv_heads]
const T *__restrict__ cache_v_scale, // [num_kv_heads]
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const int *__restrict__ seq_lens,
@@ -91,9 +92,10 @@ __global__ void multi_query_append_attention_c8_kernel(
return;
}

T cache_k_scale_reg[num_frags_y * 4];
T cache_v_scale_reg[num_frags_y * 2];
if (is_scale_channel_wise) {
T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4];
T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2];
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
@@ -114,6 +116,7 @@ __global__ void multi_query_append_attention_c8_kernel(
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
}
}

const uint32_t q_end =
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));
@@ -201,6 +204,13 @@ __global__ void multi_query_append_attention_c8_kernel(
smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)),
v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
T* k_smem_scale = nullptr;
T* v_smem_scale = nullptr;
if constexpr (IsDynamicC8) {
k_smem_scale = reinterpret_cast<T*>(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
v_smem_scale = k_smem_scale + num_frags_z * 16;
}

const uint32_t num_iterations = div_up(
@@ -282,10 +292,22 @@ __global__ void multi_query_append_attention_c8_kernel(

#pragma unroll 1
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
if constexpr (IsDynamicC8) {
produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
k_smem_scale,
cache_k_scale_reg,
block_table_now,
cache_k_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
wait_group<1>();
__syncthreads();
// s = qk
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8>(
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
&qo_smem,
&q_smem_offset_r,
&k_smem,
@@ -318,6 +340,7 @@ __global__ void multi_query_append_attention_c8_kernel(
s_frag, o_frag, m_frag, d_frag);
__syncthreads();

const int ori_kv_idx_base = kv_idx_base;
kv_idx_base += num_frags_z * 16;
produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
NUM_WARPS,
@@ -336,6 +359,18 @@ __global__ void multi_query_append_attention_c8_kernel(
chunk_end,
const_k_offset);
commit_group();
if constexpr (IsDynamicC8) {
produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
v_smem_scale,
cache_v_scale_reg,
block_table_now,
cache_v_scale,
ori_kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
wait_group<1>();
__syncthreads();

@@ -346,7 +381,9 @@ __global__ void multi_query_append_attention_c8_kernel(
BLOCK_SIZE,
T,
CacheT,
is_scale_channel_wise, IsFP8>(
is_scale_channel_wise,
IsFP8,
IsDynamicC8>(
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg);
__syncthreads();
@@ -463,14 +500,15 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true,
bool is_scale_channel_wise=false,
bool IsFP8=false>
bool IsFP8 = false,
bool IsDynamicC8 = false>
__global__ void multi_query_append_attention_c8_warp1_4_kernel(
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
CacheT *__restrict__ cache_v,
const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim]
const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim]
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const int *__restrict__ seq_lens,
@@ -522,9 +560,10 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
if (q_len <= 0) {
return;
}
T cache_k_scale_reg[num_frags_y * 4];
T cache_v_scale_reg[num_frags_y * 2];
if (is_scale_channel_wise) {
T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4];
T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2];
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
@@ -545,6 +584,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
}
}
const uint32_t q_end =
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));
uint32_t kv_len = seq_lens_kv[batch_id];
@@ -634,6 +674,13 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)),
v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
T* k_smem_scale = nullptr;
T* v_smem_scale = nullptr;
if constexpr (IsDynamicC8) {
k_smem_scale = reinterpret_cast<T*>(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
v_smem_scale = k_smem_scale + NUM_WARP_KV * num_frags_z * 16;
}

const uint32_t num_iterations = div_up(
CAUSAL
@@ -716,11 +763,23 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
commit_group();
#pragma unroll 1
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
if constexpr (IsDynamicC8) {
produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
k_smem_scale,
cache_k_scale_reg,
block_table_now,
cache_k_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
wait_group<1>();
__syncthreads();

// s = qk
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8>(
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
&qo_smem,
&q_smem_offset_r,
&k_smem,
@@ -753,6 +812,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
s_frag, o_frag, m_frag, d_frag);
__syncthreads();

const uint32_t ori_kv_idx_base = kv_idx_base;
kv_idx_base += NUM_WARP_KV * num_frags_z * 16;
produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
NUM_WARPS,
@@ -771,6 +831,18 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
chunk_end,
const_k_offset);
commit_group();
if constexpr (IsDynamicC8) {
produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
v_smem_scale,
cache_v_scale_reg,
block_table_now,
cache_v_scale,
ori_kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
wait_group<1>();
__syncthreads();

@@ -781,7 +853,9 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
BLOCK_SIZE,
T,
CacheT,
is_scale_channel_wise, IsFP8>(
is_scale_channel_wise,
IsFP8,
IsDynamicC8>(
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg);
__syncthreads();
@@ -895,7 +969,8 @@ template <typename T,
uint32_t NUM_WARP_Q,
typename OutT = T,
bool ENABLE_PREFILL = true,
bool IsFP8=false>
bool IsFP8 = false,
bool IsDynamicC8 = false>
void MultiQueryAppendC8Attention(
const AppendAttnMetaData &meta_data,
const paddle::Tensor &qkv,
@@ -953,7 +1028,8 @@ void MultiQueryAppendC8Attention(
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16;
constexpr uint32_t smem_size =
num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2;
num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 +
num_frags_z * 16 * sizeof(T) * 2;
auto split_kv_kernel =
multi_query_append_attention_c8_kernel<NV_TYPE,
uint8_t,
@@ -970,7 +1046,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false, IsFP8>;
false,
IsFP8,
IsDynamicC8>;
if (is_scale_channel_wise) {
split_kv_kernel =
multi_query_append_attention_c8_kernel<NV_TYPE,
@@ -988,7 +1066,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true, IsFP8>;
true,
IsFP8,
IsDynamicC8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(split_kv_kernel,
@@ -1022,7 +1102,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false, IsFP8>;
false,
IsFP8,
IsDynamicC8>;
if (is_scale_channel_wise) {
nosplit_kv_kernel =
multi_query_append_attention_c8_kernel<NV_TYPE,
@@ -1040,7 +1122,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true, IsFP8>;
true,
IsFP8,
IsDynamicC8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(nosplit_kv_kernel,
@@ -1218,7 +1302,8 @@ void MultiQueryAppendC8Attention(
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2;
constexpr uint32_t smem_size =
num_frags_x * 16 * HEAD_DIM * sizeof(T) +
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2;
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 +
NUM_WARP_KV * num_frags_z * 16 * sizeof(T) * 2;
auto split_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
uint8_t,
@@ -1235,7 +1320,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false, IsFP8>;
false,
IsFP8,
IsDynamicC8>;
if (is_scale_channel_wise) {
split_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
@@ -1253,7 +1340,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true, IsFP8>;
true,
IsFP8,
IsDynamicC8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(split_kv_kernel,
@@ -1295,7 +1384,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false, IsFP8>;
false,
IsFP8,
IsDynamicC8>;
if (is_scale_channel_wise) {
nosplit_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
@@ -1313,7 +1404,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true, IsFP8>;
true,
IsFP8,
IsDynamicC8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(nosplit_kv_kernel,
@@ -1546,6 +1639,7 @@ void CascadeAppendAttentionC8Kernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out) {
const auto token_num = meta_data.token_nums;
@@ -1554,6 +1648,7 @@ void CascadeAppendAttentionC8Kernel(
const auto num_heads = meta_data.q_num_heads;
const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
const auto head_dim = meta_data.head_dims;
bool is_dynamic_cfp8 = cache_quant_type_str == "block_wise_fp8";

DISPATCH_CAUSAL(
causal,
@@ -1572,6 +1667,7 @@ void CascadeAppendAttentionC8Kernel(
BLOCK_SIZE,
{DISPATCH_BLOCKSHAPE_Q(
block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, {
DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, {
MultiQueryAppendC8Attention<T,
GROUP_SIZE,
HEAD_DIM,
@@ -1580,7 +1676,9 @@ void CascadeAppendAttentionC8Kernel(
BLOCK_SHAPE_Q,
NUM_WARP_Q,
OutT,
ENABLE_PREFILL, IsFP8>(
ENABLE_PREFILL,
IsFP8,
IsDynamicC8>(
meta_data,
qkv,
cache_k,
@@ -1610,5 +1708,5 @@ void CascadeAppendAttentionC8Kernel(
is_decoder,
stream,
out);
})})})})})})
})})})})})})})
}
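DISPATCH_DyCfp8 is not defined in this diff; judging by the neighbouring DISPATCH_* macros it presumably lowers the runtime flag is_dynamic_cfp8 onto the compile-time template parameter IsDynamicC8. A plausible shape for such a macro (an assumption, not the repository's actual definition):

// Assumed dispatch helper: turn a runtime bool into a constexpr template
// argument so both kernel variants get instantiated.
#define DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, ...) \
  if (is_dynamic_cfp8) {                                   \
    constexpr bool IsDynamicC8 = true;                     \
    __VA_ARGS__                                            \
  } else {                                                 \
    constexpr bool IsDynamicC8 = false;                    \
    __VA_ARGS__                                            \
  }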
@@ -384,6 +384,113 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
}
}

template<uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_k_dynamic_scale(
T* k_smem_scale,
T* cache_k_reg,
const int* block_table_now,
const T* cache_k_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end
) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
if (tid < block_size) {
k_smem_scale[tid] = cache_k_scale_now[tid];
}
__syncthreads();
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_k_reg[fz * 2] = k_smem_scale[fz * 16 + row_id];
cache_k_reg[fz * 2 + 1] = k_smem_scale[fz * 16 + row_id + 8];
}
} else {
// 1 warp 32 tokens
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
} else {
k_smem_scale[ty * 32 + tx] = 0;
}
__syncwarp();
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_k_reg[fz * 2] = k_smem_scale[ty * 32 + fz * 16 + row_id];
cache_k_reg[fz * 2 + 1] = k_smem_scale[ty * 32 + fz * 16 + row_id + 8];
}
}
}

template<uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_v_dynamic_scale(
T* v_smem_scale,
T* cache_v_reg,
const int* block_table_now,
const T* cache_v_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end
) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;

if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
if (tid < block_size) {
v_smem_scale[tid] = cache_v_scale_now[tid];
}
__syncthreads();
const uint32_t row_id = tx % 4 * 2;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_v_reg[fz * 4] = v_smem_scale[fz * 16 + row_id];
cache_v_reg[fz * 4 + 1] = v_smem_scale[fz * 16 + row_id + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[fz * 16 + row_id + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[fz * 16 + row_id + 9];
}
} else {
// 1 warp 32 tokens
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
} else {
v_smem_scale[ty * 32 + tx] = 0;
}
__syncwarp();
const uint32_t row_id = tx % 4 * 2;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_v_reg[fz * 4] = v_smem_scale[ty * 32 + fz * 16 + row_id];
cache_v_reg[fz * 4 + 1] = v_smem_scale[ty * 32 + fz * 16 + row_id + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[ty * 32 + fz * 16 + row_id + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[ty * 32 + fz * 16 + row_id + 9];
}
}
}
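Both helpers above read the per-token dequant scales of the current KV block from global memory (one scale per token per kv head), stage them in shared memory, and then let each thread keep the entries that match the rows of its MMA fragments. The addressing they use corresponds to a [max_block_num, kv_num_heads, block_size] layout; a hypothetical index helper (not part of the diff) that mirrors that pointer arithmetic:

// Illustrative only: flat index into the per-token scale tensor laid out as
// [max_block_num, kv_num_heads, block_size], matching the arithmetic above.
__host__ __device__ inline int dynamic_scale_index(int block_id, int kv_head_idx,
                                                   int token_in_block,
                                                   int kv_num_heads, int block_size) {
  return (block_id * kv_num_heads + kv_head_idx) * block_size + token_in_block;
}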
template <SharedMemFillMode fill_mode,
uint32_t num_warps,
uint32_t block_size,
@@ -816,7 +923,8 @@ template <uint32_t num_frags_x,
typename T,
typename CacheT,
bool is_scale_channel_wise = false,
bool IsFP8=false>
bool IsFP8 = false,
bool IsDynamicC8 = false>
__device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
uint32_t* q_smem_offset_r,
smem_t* k_smem,
@@ -860,6 +968,7 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
const int scale_col = (ky * 2 + fy) * 4;
b_frag_dq_T[0] *= cache_k_scale[scale_col];
@@ -876,6 +985,12 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
b_frag_dq_T[b_i] *= cache_k_scale[0];
}
}
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4];
}
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
if (ky == 0 && fy == 0) {
@@ -1093,7 +1208,9 @@ template <uint32_t num_frags_x,
uint32_t block_size,
typename T,
typename CacheT,
bool is_scale_channel_wise = false, bool IsFP8=false>
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
__device__ __forceinline__ void compute_sfm_v_c8(
smem_t* v_smem,
uint32_t* v_smem_offset_r,
@@ -1135,6 +1252,7 @@ __device__ __forceinline__ void compute_sfm_v_c8(
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
@@ -1146,6 +1264,17 @@ __device__ __forceinline__ void compute_sfm_v_c8(
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
}
} else {
const int scale_col = (kz * 2 + fz) * 4;
b_frag_dq_T[0] *= cache_v_scale[scale_col];
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_v_scale[scale_col];
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
mma_sync_m16n16k16_row_col_f16f16f32<T>(
@@ -1171,7 +1300,9 @@ template <uint32_t num_frags_x,
uint32_t block_size,
typename T,
typename CacheT,
bool is_scale_channel_wise = false, bool IsFP8=false>
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
__device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
smem_t* v_smem,
uint32_t* v_smem_offset_r,
@@ -1215,6 +1346,7 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
@@ -1226,6 +1358,17 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
}
} else {
const int scale_col = (kz * 2 + fz) * 4;
b_frag_dq_T[0] *= cache_v_scale[scale_col];
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_v_scale[scale_col];
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
mma_sync_m16n16k16_row_col_f16f16f32<T>(
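With the extra IsDynamicC8 branch, the dequant applied right after convert_c8 now has three cases. A scalar reference of the selection logic, with illustrative names that do not appear in the diff:

// Per-element dequant of an 8-bit cache value back to T (scalar sketch):
//  - dynamic block-wise fp8: one scale per cached token row
//  - channel-wise:           one scale per head-dim column
//  - otherwise:              one scale per kv head
float dequant_cache_element(float converted, int token_row, int channel_col,
                            const float* token_scale, const float* channel_scale,
                            const float* head_scale,
                            bool is_dynamic_c8, bool is_scale_channel_wise) {
  if (is_dynamic_c8) return converted * token_scale[token_row];
  if (is_scale_channel_wise) return converted * channel_scale[channel_col];
  return converted * head_scale[0];
}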
@@ -103,6 +103,7 @@ void CascadeAppendAttentionC8Kernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

@@ -264,9 +265,10 @@ void CascadeAppendAttentionKernel(
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out);
} else if (cache_quant_type_str == "cache_fp8") {
} else if (cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
CascadeAppendAttentionC8Kernel<T, OutT, true>(meta_data,
qkv,
cache_k,
@@ -299,6 +301,7 @@ void CascadeAppendAttentionKernel(
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out);
} else if (cache_quant_type_str == "cache_int4_zp") {
@@ -674,6 +674,294 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
}
}

template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128, bool is_scale_channel_wise=false, bool IsFP8=true>
__global__ void append_decode_cache_int8_rope_qk_norm_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads,
// block_size, head_size // 2]
uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads,
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
T* __restrict__ cache_k_scale,
T* __restrict__ cache_v_scale,
const float* q_norm_weight,
const float* k_norm_weight,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads,
const bool rope_3d,
const float rms_norm_eps) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
const int tid = threadIdx.x;
const int wid = tid / 32;
const int lane_id = tid % 32;
const int bid = blockIdx.x, head_idx = blockIdx.y * NUM_WARPS + wid;
int q_head_idx, k_head_idx, v_idx;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int start_token_idx = cu_seqlens_q[bid];
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
const int* block_table_now = nullptr;

block_table_now = block_tables + bid * max_blocks_per_seq;
const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
const int block_offset = write_seq_id % block_size;

int cache_offset;
if (head_idx < num_heads) {
cache_offset = 0;
} else if (head_idx < num_heads + 2 * kv_num_heads) {
cache_offset = block_idx * kv_num_heads * block_size + (head_idx - num_heads) % kv_num_heads * block_size + block_offset;
}
T *cache_k_scale_now = cache_k_scale + cache_offset;
T *cache_v_scale_now = cache_v_scale + cache_offset;

float thread_m2 = 0.0f;
float warp_m2 = 0.0f;

if (head_idx < num_heads) {
// q
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadOutScaleT = AlignedVector<float, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;

LoadT src_vec;
LoadBiasT out_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
T* qkv_out_now = qkv_out + start_token_idx * hidden_size;
#pragma unroll
for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim;
head_bias += 32 * VecSize) {
const int bias_idx = head_idx * HeadDim + head_bias;
Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
const uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// dequant + add_bias + rope
float input_left = static_cast<float>(src_vec[2 * i]);
float input_right = static_cast<float>(src_vec[2 * i + 1]);

const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
out_vec[2 * i] =
static_cast<T>(tmp1);
out_vec[2 * i + 1] =
static_cast<T>(tmp2);
}
// qk norm
if (q_norm_weight) {
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / HeadDim, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
LoadOutScaleT q_norm_vec;
Load<float, VecSize>(&q_norm_weight[lane_id * VecSize], &q_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * q_norm_vec[i]);
}
}
Store<T, VecSize>(out_vec, &qkv_out_now[bias_idx]);
}
} else if (head_idx < num_heads + 2 * kv_num_heads) {
// k
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
const uint32_t kv_head_idx = (head_idx - num_heads) % kv_num_heads;
if (block_offset == 0) {
// pad zero for this kv_head_idx for this block
LoadPadKVT pad_cache_vec;
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
if (head_idx < num_heads + kv_num_heads) {
constexpr int num_vecs_per_head_dim = HeadDim / KV_VEC_SIZE;
constexpr int num_token_each_time = 32 / num_vecs_per_head_dim;
const uint32_t tgt_idx =
(block_idx * kv_num_heads + kv_head_idx) * block_size * HeadDim +
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
for (int block_i = lane_id / num_vecs_per_head_dim;
block_i < block_size;
block_i += num_token_each_time) {
Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
&key_cache[tgt_idx + block_i * HeadDim]);
}
} else {
const int num_vecs_per_head_dim = block_size / KV_VEC_SIZE;
const int num_token_each_time = 32 / num_vecs_per_head_dim;
const uint32_t tgt_idx =
(block_idx * kv_num_heads + kv_head_idx) * HeadDim * block_size +
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim;
block_i += num_token_each_time) {
Store<uint8_t, KV_VEC_SIZE>(
pad_cache_vec, &value_cache[tgt_idx + block_i * block_size]);
}
}
__syncwarp();
}

constexpr int K_VEC_SIZE = 4;
constexpr int HALF_K_VEC_SIZE = 2;
using LoadKVResT = AlignedVector<uint8_t, K_VEC_SIZE>;
using LoadKVT = AlignedVector<uint8_t, HALF_K_VEC_SIZE>;
using LoadT = AlignedVector<T, HALF_K_VEC_SIZE>;
using LoadBiasT = AlignedVector<T, HALF_K_VEC_SIZE>;
using LoadOutScaleT = AlignedVector<float, HALF_K_VEC_SIZE>;
using LoadEmbT = AlignedVector<float, 1>;
LoadKVResT cache_vec;
LoadT src_vec1, src_vec2;
LoadBiasT out_vec1, out_vec2;
LoadEmbT cos_emb_vec1, cos_emb_vec2;
LoadEmbT sin_emb_vec1, sin_emb_vec2;

const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2;
const int bias_idx = head_idx * HeadDim + head_bias;
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx], &src_vec1);
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
T scale = T(1.0f);
const int k_head_idx = head_idx - num_heads;
const int v_head_idx = head_idx - num_heads - kv_num_heads;
if (head_idx < num_heads + kv_num_heads) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
const uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
}

float input_left = static_cast<float>(src_vec1[0]);
float input_right = static_cast<float>(src_vec1[1]);
if (head_idx < num_heads + kv_num_heads) {
float cos_tmp = cos_emb_vec1[0];
float sin_tmp = sin_emb_vec1[0];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
out_vec1[0] =
static_cast<T>(tmp1);
out_vec1[1] =
static_cast<T>(tmp2);
} else {
out_vec1[0] = src_vec1[0];
out_vec1[1] = src_vec1[1];
}

// rope
input_left = static_cast<float>(src_vec2[0]);
input_right = static_cast<float>(src_vec2[1]);
if (head_idx < num_heads + kv_num_heads) {
float cos_tmp = cos_emb_vec2[0];
float sin_tmp = sin_emb_vec2[0];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
out_vec2[0] = static_cast<T>(tmp1);
out_vec2[1] = static_cast<T>(tmp2);
} else {
out_vec2[0] = src_vec2[0];
out_vec2[1] = src_vec2[1];
}
if (k_norm_weight) {
if (head_idx < num_heads + kv_num_heads) {
LoadOutScaleT k_norm_vec1, k_norm_vec2;
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias], &k_norm_vec1);
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias + 8], &k_norm_vec2);
// qk norm
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / HeadDim, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);

for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
out_vec1[i] = static_cast<T>(static_cast<float>(out_vec1[i]) * row_inv_var * k_norm_vec1[i]);
out_vec2[i] = static_cast<T>(static_cast<float>(out_vec2[i]) * row_inv_var * k_norm_vec2[i]);
}
}
}
// reduce max, 1 head per warp
T local_max = -INFINITY;
#pragma unroll
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
local_max = __hmax(local_max, __habs(out_vec1[i]));
local_max = __hmax(local_max, __habs(out_vec2[i]));
}
#pragma unroll
for (int m_offset = 16; m_offset > 1; m_offset /= 2) {
local_max = __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
}

scale = __hdiv(448, local_max);

if (lane_id == 0) {
if (head_idx < num_heads + kv_num_heads) {
cache_k_scale_now[0] = __hdiv(1, scale);
} else {
cache_v_scale_now[0] = __hdiv(1, scale);
}
}

#pragma unroll
for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) {
cache_vec[i] = QuantToC8<T,true, IsFP8, RoundType>(scale, out_vec1[i], max_bound, min_bound);
cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T,true, IsFP8, RoundType>(scale, out_vec2[i], max_bound, min_bound);
}
if (head_idx < num_heads + kv_num_heads) {
const int start_block_16 =
block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8;
const uint32_t tgt_cache_idx =
block_idx * kv_num_heads * block_size * HeadDim +
kv_head_idx * block_size * HeadDim + start_block_16 * HeadDim +
lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4;
Store<uint8_t, K_VEC_SIZE>(cache_vec, &key_cache[tgt_cache_idx]);
} else {
const uint32_t base_tgt_cache_idx =
block_idx * kv_num_heads * HeadDim * block_size +
kv_head_idx * HeadDim * block_size +
(lane_id / 4 * 16 + lane_id % 4 * 2) * block_size +
block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32;
const uint32_t tgt_cache_idx1 = base_tgt_cache_idx +
block_offset % 8 / 2 * 4 // per 4
+ block_offset % 16 / 8 * 2 // per 2
+ block_offset % 2; // per 1
const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size;
const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16;
const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size;
value_cache[tgt_cache_idx1] = cache_vec[0];
value_cache[tgt_cache_idx2] = cache_vec[1];
value_cache[tgt_cache_idx3] = cache_vec[2];
value_cache[tgt_cache_idx4] = cache_vec[3];
}
}
}

template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128, bool is_scale_channel_wise=false, bool IsFP8=false>
__global__ void append_decode_cache_int8_rope_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
@@ -553,9 +553,40 @@ void DecoderWriteCacheWithRoPEKernel(
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
rms_norm_eps);
} else if (cache_quant_type_str == "block_wise_fp8") {
constexpr int num_warps = 4;
const int all_warps =
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
dim3 grids(bsz, all_warps / num_warps);
append_decode_cache_int8_rope_qk_norm_kernel<DataType_, 4, 0, 128, false, true>
<<<grids, num_warps * 32, 0, stream>>>(
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
const_cast<DataType_*>(reinterpret_cast<const DataType_*>((cache_v_scale.get().data<T>()))),
q_norm_weight.get().data<float>(),
k_norm_weight.get().data<float>(),
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d,
rms_norm_eps);
} else {
PD_THROW(
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
"append_decode_cache_rope_qk_norm just supports cache_quant_type none/block_wise_fp8");
}
} else {
if (cache_quant_type_str == "none") {
@@ -686,6 +717,37 @@ void DecoderWriteCacheWithRoPEKernel(
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "block_wise_fp8") {
constexpr int num_warps = 4;
const int all_warps =
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
dim3 grids(bsz, all_warps / num_warps);
append_decode_cache_int8_rope_qk_norm_kernel<DataType_, 4, 0, 128, false, true>
<<<grids, num_warps * 32, 0, stream>>>(
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
const_cast<DataType_*>(reinterpret_cast<const DataType_*>((cache_v_scale.get().data<T>()))),
nullptr,
nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d,
rms_norm_eps);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_decode_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),
@@ -1232,6 +1232,411 @@ __global__ void append_write_cache_kv_c8_qkv(
}
}

template <typename T,
uint32_t num_frags_y,
uint32_t num_frags_z,
uint32_t HEAD_DIM,
uint32_t BLOCK_SIZE,
uint32_t NUM_WARPS,
bool is_need_kv_quant,
bool IsFP8 = true>
__global__ void append_write_cache_kv_c8_qkv_dynamic(
uint8_t *__restrict__ cache_k,
uint8_t *__restrict__ cache_v,
const T *__restrict__ qkv_input,
T *__restrict__ cache_k_scales, // [block_num, num_heads, block_size]
T *__restrict__ cache_v_scales, // [block_num, num_heads, block_size]
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids,
const int *__restrict__ seq_lens_this_time,
const int *__restrict__ seq_lens_decoder,
const int *__restrict__ batch_id_per_token,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_tables,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int kv_num_heads) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t pad_len = BLOCK_SIZE;
const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z;
const T cache_k_scale = cache_k_scales[kv_head_idx];
const T cache_v_scale = cache_v_scales[kv_head_idx];
const uint32_t tid = threadIdx.x, wid = threadIdx.y;
const uint32_t batch_id = batch_ids[btid];
const uint32_t tile_id = tile_ids[btid];
const uint32_t seq_len_this_time = seq_lens_this_time[batch_id];
if (seq_len_this_time <= 0) {
return;
}
const int *block_table_now = nullptr;

block_table_now = block_tables + batch_id * max_blocks_per_seq;

const uint32_t num_rows_per_block =
NUM_WARPS * num_frags_z * 16; // BLOCK_SIZE
const uint32_t start_len = seq_lens_decoder[batch_id];
const uint32_t bf_pad_len = start_len % pad_len;
const uint32_t start_len_pad = start_len - bf_pad_len;
const uint32_t end_len = start_len + seq_len_this_time;

const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]);
uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;

const uint32_t start_token_idx = cu_seqlens_q[batch_id];
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
const uint32_t kv_h_stride = HEAD_DIM;
__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
__shared__ T v_smem_ori[num_rows_per_block * HEAD_DIM];
__shared__ T v_scale_smem[BLOCK_SIZE];
if (tile_start >= start_len) {
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
// pad zero for this kv_head_idx for this block
LoadPadKVT pad_cache_vec;
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
// reset k
constexpr int num_vecs_per_head_k = HEAD_DIM / KV_VEC_SIZE;
constexpr int num_token_each_time_k = 32 / num_vecs_per_head_k;
uint32_t tgt_idx =
(block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE * HEAD_DIM +
tid % num_vecs_per_head_k * KV_VEC_SIZE;
for (int block_i = tid / num_vecs_per_head_k;
block_i < BLOCK_SIZE;
block_i += num_token_each_time_k) {
Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
&cache_k[tgt_idx + block_i * HEAD_DIM]);
}

// reset v
const int num_vecs_per_head_v = BLOCK_SIZE / KV_VEC_SIZE;
const int num_token_each_time_v = 32 / num_vecs_per_head_v;
tgt_idx =
(block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE +
tid % num_vecs_per_head_v * KV_VEC_SIZE;
for (int block_i = tid / num_vecs_per_head_v; block_i < HEAD_DIM;
block_i += num_token_each_time_v) {
Store<uint8_t, KV_VEC_SIZE>(
pad_cache_vec, &cache_v[tgt_idx + block_i * BLOCK_SIZE]);
}
}
smem_t k_smem(k_smem_ori);
smem_t v_smem(v_smem_ori);

uint32_t kv_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
wid * num_frags_z * 16 + tid / 8, tid % 8); // 4 * 8 per warp

/*
0 | 1
2 | 3
*/
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);

constexpr uint32_t num_frags_v = num_frags_y / NUM_WARPS;
/*
0 | 2
1 | 3
*/
uint32_t v_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
tid % 16, wid * num_frags_v * 2 + tid / 16);

// load kv gmem to smem
const uint32_t real_start_token_idx = start_token_idx - bf_pad_len +
tile_id * num_rows_per_block +
wid * num_frags_z * 16 + tid / 8;
uint32_t k_read_idx = real_start_token_idx * kv_batch_stride +
(num_heads + kv_head_idx) * kv_h_stride +
tid % 8 * num_elems_per_128b<T>();
uint32_t v_read_idx = real_start_token_idx * kv_batch_stride +
(num_heads + kv_num_heads + kv_head_idx) * kv_h_stride +
tid % 8 * num_elems_per_128b<T>();
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
#pragma unroll
for (uint32_t j = 0; j < 4; ++j) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y / 4;
++fy) { // (num_frags_y * 16) / (8 * num_elems_per_128b<T>())
if (chunk_start >= start_len && chunk_start < end_len) {
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
kv_smem_offset_w, qkv_input + k_read_idx, chunk_start < end_len);
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
kv_smem_offset_w, qkv_input + v_read_idx, chunk_start < end_len);
}
kv_smem_offset_w =
k_smem.advance_offset_by_column<8>(kv_smem_offset_w, fy);
k_read_idx += 8 * num_elems_per_128b<T>();
v_read_idx += 8 * num_elems_per_128b<T>();
}
kv_smem_offset_w =
k_smem.advance_offset_by_row<4, num_vecs_per_head>(kv_smem_offset_w) -
2 * num_frags_y;
chunk_start += 4;
k_read_idx +=
4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b<T>();
v_read_idx +=
4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b<T>();
}
}
commit_group();
wait_group<0>();
__syncthreads();
// reduce scale
// 16 rows per warp
uint32_t kv_reduce_frag[4];
T *kv_reduce_frag_T = reinterpret_cast<T*>(kv_reduce_frag);

T k_local_max_value[num_frags_z * 2];
T v_local_max_value[num_frags_z * 2];
#pragma unroll
for (int i = 0; i < num_frags_z * 2; i++) {
k_local_max_value[i] = -INFINITY;
}
#pragma unroll
for (int i = 0; i < num_frags_z * 2; i++) {
v_local_max_value[i] = -INFINITY;
}
const int num_kv_heads = gridDim.z;
const int scale_offset = block_id * num_kv_heads * BLOCK_SIZE + kv_head_idx * BLOCK_SIZE;
T *cache_k_scale_now = cache_k_scales + scale_offset;
T *cache_v_scale_now = cache_v_scales + scale_offset;
// k scale
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
// reduce per thread, 4 threads each row
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag);
#pragma unroll
for (int i = 0; i < 4; i++) {
k_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), k_local_max_value[fz * 2]);
}
#pragma unroll
for (int i = 0; i < 4; i++) {
k_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), k_local_max_value[fz * 2 + 1]);
}
k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
}
// reduce per row
for (int i = 0; i < 2; i++) {
T local_max_value = __habs(k_local_max_value[fz * 2 + i]);
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2));
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1));
// used for quant
k_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value);
}
// store
if (tid % 4 == 0) {
const int offset_now = wid * num_frags_z * 16 + tid / 4;
// used for dequant
if (tile_start + offset_now >= start_len) {
if (tile_start + offset_now < end_len) {
cache_k_scale_now[offset_now] = __hdiv(1, k_local_max_value[fz * 2]);
} else {
cache_k_scale_now[offset_now] = 0;
}
}
if (tile_start + offset_now + 8 >= start_len) {
if (tile_start + offset_now + 8 < end_len) {
cache_k_scale_now[offset_now + 8] = __hdiv(1, k_local_max_value[fz * 2 + 1]);
} else {
cache_k_scale_now[offset_now + 8] = 0;
}
}
}
__syncthreads();
k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1
}
// v scale
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
// reduce per thread, 4 threads each row
v_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag);
#pragma unroll
for (int i = 0; i < 4; i++) {
v_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), v_local_max_value[fz * 2]);
}
#pragma unroll
for (int i = 0; i < 4; i++) {
v_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), v_local_max_value[fz * 2 + 1]);
}
k_smem_offset_r = v_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
}
// reduce per row
for (int i = 0; i < 2; i++) {
T local_max_value = __habs(v_local_max_value[fz * 2 + i]);
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2));
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1));
v_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value);
}
// store
if (tid % 4 == 0) {
const int offset_now = wid * num_frags_z * 16 + tid / 4;
// used for dequant
if (tile_start + offset_now >= start_len) {
if (tile_start + offset_now < end_len) {
cache_v_scale_now[offset_now] = __hdiv(1, v_local_max_value[fz * 2]);
v_scale_smem[offset_now] = v_local_max_value[fz * 2];
} else {
cache_v_scale_now[offset_now] = 0;
v_scale_smem[offset_now] = 0;
}
}
if (tile_start + offset_now + 8 >= start_len) {
if (tile_start + offset_now + 8 < end_len) {
cache_v_scale_now[offset_now + 8] = __hdiv(1, v_local_max_value[fz * 2 + 1]);
v_scale_smem[offset_now + 8] = v_local_max_value[fz * 2 + 1];
} else {
cache_v_scale_now[offset_now + 8] = 0;
v_scale_smem[offset_now + 8] = 0;
}
}
}
__syncthreads();
k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1
}
__syncthreads();
  // mask, quant, store
  using LoadKVT = AlignedVector<uint8_t, 4>;
  LoadKVT cache_vec1;
  LoadKVT cache_vec2;

  uint32_t chunk_start_k = tile_start + wid * num_frags_z * 16 + tid / 4;
  uint32_t kv_frag[4];
  const uint32_t write_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
  const uint32_t write_h_stride = BLOCK_SIZE * HEAD_DIM;
  const uint32_t write_b_stride = HEAD_DIM;
  const uint32_t write_d_stride = BLOCK_SIZE;
  uint32_t k_write_idx = block_id * write_n_stride +
                         kv_head_idx * write_h_stride +
                         (wid * num_frags_z * 16 + tid / 4) * write_b_stride +
                         tid % 4 * 4;  // 4 * int8 = 8 * int4 = 32bit
#pragma unroll
  for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
    uint32_t k_write_idx_now_z = k_write_idx + fz * 16 * write_b_stride;
#pragma unroll
    for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
      uint32_t k_write_idx_now = k_write_idx_now_z +
                                 fy % 2 * 8 * write_b_stride +
                                 fy / 2 * 32;  // + fy % 2 * 16;
      // load
      k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag);
      // quant
      T *k_frag_T = reinterpret_cast<T *>(kv_frag);
      if (bf_pad_len != 0) {
        Load<uint8_t, 4>(cache_k + k_write_idx_now, &cache_vec1);
        Load<uint8_t, 4>(cache_k + k_write_idx_now + 16, &cache_vec2);
      }
#pragma unroll
      for (uint32_t v_id = 0; v_id < 8; ++v_id) {
        uint8_t uint_quant_value;
        if (chunk_start_k + (v_id / 4) * 8 >= start_len &&
            chunk_start_k + (v_id / 4) * 8 < end_len) {
          uint_quant_value = QuantToC8<T, is_need_kv_quant, IsFP8>(k_local_max_value[fz * 2 + v_id / 4], k_frag_T[v_id], 127.0f, -127.0f);
        } else {
          uint_quant_value = 0;
        }
        if (bf_pad_len != 0) {
          if (v_id < 4) {
            cache_vec1[v_id] |= uint_quant_value;
          } else {
            cache_vec2[v_id % 4] |= uint_quant_value;
          }
        } else {
          if (v_id < 4) {
            cache_vec1[v_id] = uint_quant_value;
          } else {
            cache_vec2[v_id - 4] = uint_quant_value;
          }
        }
      }
      // store
      Store<uint8_t, 4>(cache_vec1, cache_k + k_write_idx_now);
      Store<uint8_t, 4>(cache_vec2, cache_k + k_write_idx_now + 16);
      k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
    }
    k_smem_offset_r =
        k_smem.advance_offset_by_row<16, num_vecs_per_head>(k_smem_offset_r) -
        2 * num_frags_y;
    chunk_start_k += 16;
  }

  uint32_t chunk_start_v = tile_start + tid % 4 * 2;
  uint32_t v_write_idx = block_id * write_n_stride +
                         kv_head_idx * write_h_stride +
                         (wid * num_frags_v * 16 + tid / 4) * write_d_stride +
                         tid % 4 * 4;  // 4 * int8 = 8 * int4 = 32bit
  const uint32_t num_frags_z_v = num_frags_z * NUM_WARPS;
  T v_scales[num_frags_z_v * 4];
  for (int v_i = 0; v_i < num_frags_z_v; v_i++) {
    const int offset = v_i * 16;
    const int t_offset = tid % 4 * 2;
    v_scales[v_i * 4] = v_scale_smem[offset + t_offset];
    v_scales[v_i * 4 + 1] = v_scale_smem[offset + t_offset + 1];
    v_scales[v_i * 4 + 2] = v_scale_smem[offset + t_offset + 8];
    v_scales[v_i * 4 + 3] = v_scale_smem[offset + t_offset + 9];
  }

#pragma unroll
  for (uint32_t fy = 0; fy < num_frags_v; ++fy) {
    uint32_t v_write_idx_now_v = v_write_idx + fy * 16 * write_d_stride;
#pragma unroll
    for (uint32_t fz = 0; fz < num_frags_z_v; ++fz) {
      uint32_t v_write_idx_now = v_write_idx_now_v +
                                 fz % 2 * 8 * write_d_stride +
                                 fz / 2 * 32;  // + fz % 2 * 16;
      // load
      v_smem.ldmatrix_m8n8x4_trans(v_smem_offset_r, kv_frag);
      // quant
      T *v_frag_T = reinterpret_cast<T *>(kv_frag);
      if (bf_pad_len != 0) {
        Load<uint8_t, 4>(cache_v + v_write_idx_now, &cache_vec1);
        Load<uint8_t, 4>(cache_v + v_write_idx_now + 16, &cache_vec2);
      }
#pragma unroll
      for (uint32_t v_id = 0; v_id < 8; ++v_id) {
        uint8_t uint_quant_value;
        if (chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 >= start_len &&
            chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 < end_len) {
          uint_quant_value = QuantToC8<T, is_need_kv_quant, IsFP8>(v_scales[fz * 4 + v_id % 4], v_frag_T[v_id], 127.0f, -127.0f);
          // store now
        } else {
          uint_quant_value = 0;
        }
        if (bf_pad_len != 0) {
          if (v_id < 4) {
            cache_vec1[v_id] |= uint_quant_value;
          } else {
            cache_vec2[v_id % 4] |= uint_quant_value;
          }
        } else {
          if (v_id < 4) {
            cache_vec1[v_id] = uint_quant_value;
          } else {
            cache_vec2[v_id % 4] = uint_quant_value;
          }
        }
      }
      // store
      Store<uint8_t, 4>(cache_vec1, cache_v + v_write_idx_now);
      Store<uint8_t, 4>(cache_vec2, cache_v + v_write_idx_now + 16);
      chunk_start_v += 16;
      v_smem_offset_r =
          k_smem.advance_offset_by_row<16, num_vecs_per_head>(v_smem_offset_r);
    }
    v_smem_offset_r = k_smem.advance_offset_by_column<2>(
                          v_smem_offset_r, wid * num_frags_v + fy) -
                      16 * num_frags_z_v * num_vecs_per_head;
    chunk_start_v -= 16 * num_frags_z_v;
  }
}
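
The kernel above implements per-token ("block-wise") dynamic quantization: for every row of a KV block it takes the max absolute value, emits the row as 8-bit codes scaled by that max, and keeps the scales in a separate per-token tensor so the attention kernel can dequantize later. The sketch below is illustrative only, not FastDeploy code; it uses numpy plus the third-party ml_dtypes package to emulate the float8_e4m3 cast, and the exact bound and whether the scale or its reciprocal is persisted follow the kernel above.

    # Illustrative sketch: per-row dynamic FP8 quantization of one V-cache block.
    import numpy as np
    import ml_dtypes  # assumed available, only used here for the e4m3 cast

    def quant_block(x, max_bound=448.0):
        # x: [block_size, head_dim] -> (fp8 payload, per-row scale)
        row_max = np.maximum(np.abs(x).max(axis=-1, keepdims=True), 1e-6)
        q = (x / row_max * max_bound).astype(ml_dtypes.float8_e4m3fn)
        return q, row_max.squeeze(-1)

    def dequant_block(q, row_max, max_bound=448.0):
        return q.astype(np.float32) / max_bound * row_max[:, None]

    x = np.random.randn(64, 128).astype(np.float32)  # one block, one kv head
    q, scale = quant_block(x)
    print("mean abs error:", np.abs(dequant_block(q, scale) - x).mean())
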
// Write Cache KV in Append
template <typename T,
          uint32_t num_frags_y,
@@ -2006,10 +2411,11 @@ void CascadeAppendWriteCacheKVC8QKV(
    int num_blocks_x_cpu,
    int max_seq_len,
    bool is_scale_channel_wise,
    const bool is_fp8,
    const std::string& cache_quant_type,
    cudaStream_t &stream,
    paddle::Tensor *cache_k_out,
    paddle::Tensor *cache_v_out) {
  using NV_TYPE = typename cascade_attn_type_traits<T>::type;
  auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
  auto num_tokens = meta_data.token_nums;
  auto num_heads = meta_data.q_num_heads;
@@ -2027,6 +2433,7 @@ void CascadeAppendWriteCacheKVC8QKV(
  dim3 blocks(32, num_warps);

  const uint32_t smem_size = (BLOCK_SIZE * HEAD_DIM) * sizeof(T) * 2;
  if (cache_quant_type != "block_wise_fp8") {
    auto kernel_fn = append_write_cache_kv_c8_qkv<T,
                                                  num_frags_y,
                                                  num_frags_z,
@@ -2034,7 +2441,7 @@ void CascadeAppendWriteCacheKVC8QKV(
                                                  BLOCK_SIZE,
                                                  num_warps,
                                                  true, false>;
    if (is_fp8) {
    if (cache_quant_type == "cache_fp8") {
      kernel_fn = append_write_cache_kv_c8_qkv<T,
                                               num_frags_y,
                                               num_frags_z,
@@ -2070,6 +2477,33 @@ void CascadeAppendWriteCacheKVC8QKV(
                                          max_blocks_per_seq,
                                          num_heads,
                                          kv_num_heads);
  } else {
    auto kernel_fn = append_write_cache_kv_c8_qkv_dynamic<NV_TYPE,
                                                          num_frags_y,
                                                          num_frags_z,
                                                          HEAD_DIM,
                                                          BLOCK_SIZE,
                                                          num_warps,
                                                          true, true>;
    cudaFuncSetAttribute(
        kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
    kernel_fn<<<grids, blocks, 0, stream>>>(
        cache_k_out->data<uint8_t>(),
        cache_v_out->data<uint8_t>(),
        reinterpret_cast<const NV_TYPE*>(qkv.data<T>()),
        const_cast<NV_TYPE*>(reinterpret_cast<const NV_TYPE*>(cache_k_scale.data<T>())),
        const_cast<NV_TYPE*>(reinterpret_cast<const NV_TYPE*>(cache_v_scale.data<T>())),
        batch_ids.data<int>(),
        tile_ids_per_batch.data<int>(),
        seq_lens_this_time.data<int>(),
        seq_lens_decoder.data<int>(),
        batch_id_per_token.data<int>(),
        cu_seqlens_q.data<int>(),
        block_table.data<int>(),
        max_seq_len,
        max_blocks_per_seq,
        num_heads,
        kv_num_heads);
  }
}

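The host-side change above replaces the old is_fp8 flag with a cache_quant_type string and routes "block_wise_fp8" to the new append_write_cache_kv_c8_qkv_dynamic kernel. A compact way to read the dispatch (illustrative Python, not part of the repo; the exact template flags follow the C++ above):

    def pick_write_cache_kernel(cache_quant_type: str) -> str:
        # mirrors the branch structure of CascadeAppendWriteCacheKVC8QKV above
        if cache_quant_type == "block_wise_fp8":
            return "append_write_cache_kv_c8_qkv_dynamic"
        if cache_quant_type == "cache_fp8":
            return "append_write_cache_kv_c8_qkv (FP8 instantiation)"
        return "append_write_cache_kv_c8_qkv (int8 instantiation)"
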
template <typename T, uint32_t HEAD_DIM, uint32_t BLOCK_SIZE>

@@ -167,7 +167,7 @@ void EncoderWriteCacheWithRopeKernel(
        stream,
        key_cache_out,
        value_cache_out);
  } else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8") {
  } else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
    DISPATCH_HEAD_DIM(
        head_dim, HEAD_DIM, {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, {
          CascadeAppendWriteCacheKVC8QKV<T, HEAD_DIM, BLOCK_SIZE>(
@@ -187,7 +187,7 @@ void EncoderWriteCacheWithRopeKernel(
              num_blocks,
              max_seq_len,
              is_scale_channel_wise,
              cache_quant_type_str == "cache_fp8",
              cache_quant_type_str,
              stream,
              key_cache_out,
              value_cache_out);

@@ -1000,7 +1000,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
        stream,
        const_cast<paddle::Tensor*>(&key_cache),
        const_cast<paddle::Tensor*>(&value_cache));
  } else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
  } else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8" || cache_quant_type == "block_wise_fp8") {
    CascadeAppendWriteCacheKVC8QKV<data_t, 128, 64>(
        meta_data,
        *const_cast<paddle::Tensor*>(&key_cache),
@@ -1018,7 +1018,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
        kv_num_blocks_data,
        max_seq_len,
        false, // is_scale_channel_wise
        cache_quant_type == "cache_fp8", // is_fp8
        cache_quant_type,
        stream,
        const_cast<paddle::Tensor*>(&key_cache),
        const_cast<paddle::Tensor*>(&value_cache));
@@ -56,6 +56,7 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, false>(
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);

@@ -103,5 +104,6 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, true>(
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);

@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);

@@ -98,5 +99,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);

@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);

@@ -100,5 +101,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);

@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, f
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);

@@ -100,5 +101,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, t
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);

@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);

@@ -99,5 +100,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);

@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);

@@ -99,5 +100,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
    const bool causal,
    const bool is_decoder,
    const bool enable_prefill,
    const std::string& cache_quant_type_str,
    cudaStream_t& stream,
    paddle::Tensor* out);
@@ -441,6 +441,15 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
    PD_THROW("not support the group_size", group_size); \
  }

#define DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, ...) \
  if (is_dynamic_cfp8) {                                   \
    constexpr bool IsDynamicC8 = true;                     \
    __VA_ARGS__                                            \
  } else {                                                 \
    constexpr bool IsDynamicC8 = false;                    \
    __VA_ARGS__                                            \
  }

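// Illustrative usage note (not part of this commit): DISPATCH_DyCfp8 turns the
// runtime block-wise-FP8 decision into the IsDynamicC8 template flag, e.g.
//   DISPATCH_DyCfp8(cache_quant_type_str == "block_wise_fp8", IsDynamicC8,
//                   {LaunchSomeC8Kernel<T, IsDynamicC8>(params);})
// where LaunchSomeC8Kernel and params are placeholders for the actual launcher.
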
#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
  if (group_size == 8) {                                     \
    constexpr size_t GROUP_SIZE = 8;                         \
@@ -231,6 +231,17 @@ class AppendAttentionBackend(AttentionBackend):
                metadata.kv_signal_metadata,
                layer.layer_id + self.start_layer_index,
            )
        cache_quant_type_str = getattr(layer, "cache_quant_type_str", "none")
        if cache_quant_type_str == "block_wise_fp8":
            cache_k = forward_meta.caches[4 * layer.layer_id]
            cache_v = forward_meta.caches[4 * layer.layer_id + 1]
            cache_k_scales = forward_meta.caches[4 * layer.layer_id + 2]
            cache_v_scales = forward_meta.caches[4 * layer.layer_id + 3]
        else:
            cache_k = forward_meta.caches[2 * layer.layer_id]
            cache_v = forward_meta.caches[2 * layer.layer_id + 1]
            cache_k_scales = getattr(layer, "cache_k_scale", None)
            cache_v_scales = getattr(layer, "cache_v_scale", None)

        if self.use_output:
            quant_max_bound = getattr(layer, "quant_max_bound", 0.0)
@@ -269,8 +280,8 @@ class AppendAttentionBackend(AttentionBackend):

            append_attention_with_output(
                qkv,
                forward_meta.caches[2 * layer.layer_id],
                forward_meta.caches[2 * layer.layer_id + 1],
                cache_k,
                cache_v,
                forward_meta.seq_lens_encoder,
                forward_meta.seq_lens_decoder,
                forward_meta.seq_lens_this_time,
@@ -293,8 +304,8 @@ class AppendAttentionBackend(AttentionBackend):
                metadata.attn_mask,
                layer.qkv_bias,
                layer.qkv_scale,
                getattr(layer, "cache_k_scale", None),
                getattr(layer, "cache_v_scale", None),
                cache_k_scales,
                cache_v_scales,
                getattr(layer, "cache_k_out_scale", None),
                getattr(layer, "cache_v_out_scale", None),
                getattr(layer, "cache_k_zp", None),
@@ -325,8 +336,8 @@ class AppendAttentionBackend(AttentionBackend):
        else:
            res = append_attention(
                qkv,
                forward_meta.caches[2 * layer.layer_id],
                forward_meta.caches[2 * layer.layer_id + 1],
                cache_k,
                cache_v,
                forward_meta.seq_lens_encoder,
                forward_meta.seq_lens_decoder,
                forward_meta.seq_lens_this_time,
@@ -348,8 +359,8 @@ class AppendAttentionBackend(AttentionBackend):
                metadata.attn_mask,
                layer.qkv_bias,
                layer.qkv_scale,
                getattr(layer, "cache_k_scale", None),
                getattr(layer, "cache_v_scale", None),
                cache_k_scales,
                cache_v_scales,
                getattr(layer, "cache_k_out_scale", None),
                getattr(layer, "cache_v_out_scale", None),
                getattr(layer, "cache_k_zp", None),
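
With block_wise_fp8 each layer contributes four entries to forward_meta.caches (K data, V data, K scales, V scales) instead of two, which is why the backend above switches between a per-layer stride of 4 and 2. A minimal sketch of that layout, assuming this hypothetical helper name:

    def layer_cache_slice(caches, layer_id, block_wise_fp8: bool):
        """Return (cache_k, cache_v, k_scales, v_scales) for one layer."""
        if block_wise_fp8:
            base = 4 * layer_id
            return caches[base], caches[base + 1], caches[base + 2], caches[base + 3]
        base = 2 * layer_id
        return caches[base], caches[base + 1], None, None
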
@@ -33,6 +33,7 @@ class KvCacheQuantzationTypes(str, Enum):

    INT8 = "int8"
    FP8 = "float8_e4m3fn"
    BLOCK_WISE_FP8 = "block_wise_fp8"
    INT8_ZP = "int8_zp"
    INT4_ZP = "int4_zp"
    FP8_ZP = "float8_e4m3fn_zp"
@@ -62,7 +63,11 @@ class KvCacheQuantConfig(QuantConfigBase):

        if self.quant_type == KvCacheQuantzationTypes.INT8 or self.quant_type == KvCacheQuantzationTypes.INT8_ZP:
            self.max_bound = 127.0
        elif self.quant_type == KvCacheQuantzationTypes.FP8 or self.quant_type == KvCacheQuantzationTypes.FP8_ZP:
        elif (
            self.quant_type == KvCacheQuantzationTypes.FP8
            or self.quant_type == KvCacheQuantzationTypes.FP8_ZP
            or self.quant_type == KvCacheQuantzationTypes.BLOCK_WISE_FP8
        ):
            self.max_bound = 448.0
        elif self.quant_type == KvCacheQuantzationTypes.INT4_ZP:
            self.max_bound = 7.0
@@ -178,9 +183,14 @@ class KVCacheMethodBase(QuantMethodBase):
            layer.cache_quant_type_str = "cache_int4_zp"
            layer.quant_max_bound = 7.0
            layer.quant_min_bound = -7.0
        elif self.cache_quant_config.quant_type == KvCacheQuantzationTypes.BLOCK_WISE_FP8:
            layer.cache_quant_type_str = "block_wise_fp8"
            layer.quant_max_bound = 448.0
            layer.quant_min_bound = -448.0
        else:
            raise NotImplementedError(f"{self.cache_quant_config.quant_type} is not implemented")

        if "block_wise" not in layer.cache_quant_type_str:
            self.load_scale(layer, state_dict)
        if self.cache_quant_config.has_zero_point:
            self.load_zp(layer, state_dict)
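
For quick reference, the bounds chosen above can be read as a small table: block_wise_fp8 shares the FP8 e4m3 limit of 448 and, unlike the static modes, skips load_scale because its scales are produced at runtime. The dict below is only a condensed restatement of the branches above, keyed by the enum string values:

    KV_CACHE_MAX_BOUND = {
        "int8": 127.0,
        "int8_zp": 127.0,
        "float8_e4m3fn": 448.0,
        "float8_e4m3fn_zp": 448.0,
        "block_wise_fp8": 448.0,
        "int4_zp": 7.0,
    }
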
@@ -1023,6 +1023,8 @@ class GPUModelRunner(ModelRunnerBase):
        kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
            max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
        )
        if kv_cache_quant_type == "block_wise_fp8":
            kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
        local_rank = self.local_rank % self.parallel_config.tensor_parallel_size

        if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
@@ -1050,6 +1052,17 @@ class GPUModelRunner(ModelRunnerBase):
                    fill_value=0,
                    dtype=cache_type,
                )
                if kv_cache_quant_type == "block_wise_fp8":
                    cache_kvs[f"key_cache_scales_{i}"] = paddle.full(
                        shape=kv_cache_scale_shape,
                        fill_value=0,
                        dtype=paddle.get_default_dtype(),
                    )
                    cache_kvs[f"value_cache_scales_{i}"] = paddle.full(
                        shape=kv_cache_scale_shape,
                        fill_value=0,
                        dtype=paddle.get_default_dtype(),
                    )
            self.share_inputs["caches"] = list(cache_kvs.values())
            for value in cache_kvs.values():
                del value
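
The scale tensors allocated above simply drop the head_dim axis of the KV cache shape, i.e. one scale per (block, kv_head, token slot). A minimal sketch, assuming the [max_block_num, kv_num_heads, block_size, head_dim] cache layout used by the write-cache kernels and illustrative example values:

    kv_cache_shape = [1024, 8, 64, 128]        # assumed example values
    kv_cache_scale_shape = kv_cache_shape[:3]  # -> [1024, 8, 64]
    print(kv_cache_scale_shape)
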
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import time
import unittest

@@ -20,6 +21,7 @@ import paddle
from paddle.incubate.nn.functional import fused_rms_norm

paddle.seed(10)
np.random.seed(10)


class RopeEmbedding:
@@ -334,7 +336,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
        self.name = "TestAppendGroupQueryAttnWithRope"
        self.place = paddle.CUDAPlace(0)
        self.batch_size = 1
        self.q_num_head = 12
        self.q_num_head = 16
        self.kv_num_head = 2
        self.seq_len = 64
        self.max_dec_len = 64
@@ -347,9 +349,10 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
        self.max_seq_len = self.seq_len + self.max_dec_len
        self.softmax_scale = self.dim_head**-0.5
        self.rope_theta = 10000
        self.dtype = "float16"
        self.dtype = "bfloat16"
        self.use_qk_norm = True
        self.use_mask_offset = False
        self.use_dynamic_quant = False
        self.init_tensor()

    def init_tensor(self):
@@ -391,8 +394,23 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
        )

        self.scale = 1.0 / np.sqrt(self.dim_head)
        if self.use_dynamic_quant:
            self.cache_scale_shape = (
                self.max_block_num,
                self.kv_num_head,
                self.blocksize,
            )
            self.cache_k = paddle.zeros(shape=self.cache_shape, dtype="uint8")
            self.cache_v = paddle.zeros(shape=self.cache_shape, dtype="uint8")
            self.cache_k_T = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
            self.cache_v_T = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
            self.key_cache_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype)
            self.value_cache_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype)
        else:
            self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
            self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
            self.key_cache_scale = None
            self.value_cache_scale = None
        self.block_tables = paddle.zeros(shape=(self.batch_size, self.block_num_per_seq), dtype="int32")
        for i in range(self.batch_size):
            need_block_num = (self.seq_len + self.max_dec_len + self.blocksize - 1) // self.blocksize
@@ -415,6 +433,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):

    def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask=None):
        paddle.disable_static()
        print("use_dynamic_quant: ", self.use_dynamic_quant)
        self.token_num = self.seq_len * self.batch_size
        q, k, v, qkv = get_qkv_and_qkv_concat_tensor(
            self.batch_size,
@@ -472,18 +491,17 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
            self.blocksize,
            speculate_max_draft_token_num + 1,
        )
        if self.use_dynamic_quant:
            cache_quant_type = "block_wise_fp8"
        else:
            cache_quant_type = "none"

        # Warm up
        WARM_UP = 1
        RUN_TIME = 2
        for i in range(WARM_UP + RUN_TIME):
            if i == WARM_UP:
                paddle.device.synchronize()
                start_time = time.time()
            out = append_attention(
                qkv,
                self.cache_k,
                self.cache_v,
        if self.use_dynamic_quant:
            qkv_copy = copy.deepcopy(qkv)
            append_attention(
                qkv_copy,
                self.cache_k_T,
                self.cache_v_T,
                self.seq_lens_encoder,
                self.seq_lens_decoder,
                self.seq_lens_this_time,
@@ -519,7 +537,69 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
                k_norm_weight, # k_norm_weight
                1e-6,
                "fp16",
                "none", # cache_quant_type
                "none",
                self.use_neox_rotary_style,
                False,
                self.max_seq_len,
                0.0, # quant_min_bound
                0.0, # quant_max_bound
                -1, # out_linear_in_scale
                64, # encoder_block_shape_q
                16, # decoder_block_shape_q
                32768, # max_partition_size
                32768, # encoder_max_partition_size
                speculate_max_draft_token_num + 1, # speculate_max_draft_token_num
                True, # causal
                False, # speculate_decoder
            )

        # Warm up
        WARM_UP = 1
        RUN_TIME = 2
        for i in range(WARM_UP + RUN_TIME):
            if i == WARM_UP:
                paddle.device.synchronize()
                start_time = time.time()
            out = append_attention(
                qkv,
                self.cache_k,
                self.cache_v,
                self.seq_lens_encoder,
                self.seq_lens_decoder,
                self.seq_lens_this_time,
                self.padding_offset,
                self.cum_offset,
                self.block_tables,
                encoder_batch_ids,
                encoder_tile_ids_per_batch,
                encoder_num_blocks,
                kv_batch_ids,
                kv_tile_ids_per_batch,
                kv_num_blocks,
                self.decoder_batch_ids,
                self.decoder_tile_ids_per_batch,
                self.decoder_num_blocks_cpu,
                self.max_len_tensor_cpu,
                max_len_kv,
                self.rope_emb, # rope_emb
                None, # attn_mask
                None, # qkv_bias
                None, # qkv_out_scales
                self.key_cache_scale, # cache_k_quant_scales
                self.value_cache_scale, # cache_v_quant_scales
                None, # cache_k_dequant_scales
                None, # cache_v_dequant_scales
                None, # cache_k_zp
                None, # cache_v_zp
                None, # linear_shift
                None, # linear_smooth
                self.mask_offset, # mask_offset
                None, # kv_signal_data
                q_norm_weight, # q_norm_weight
                k_norm_weight, # k_norm_weight
                1e-6,
                "fp16",
                cache_quant_type,
                self.use_neox_rotary_style,
                False,
                self.max_seq_len,
@@ -537,13 +617,6 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
        paddle.device.synchronize()
        end_time = time.time()
        print(f"[append-attn ut] cost_time:{(end_time - start_time) / RUN_TIME * 1000}ms")
        naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
            self.cache_k,
            self.cache_v,
            self.batch_size,
            self.block_tables,
            self.seq_len,
        )
        np.testing.assert_allclose(
            out.numpy(),
            out_.numpy(),
@@ -572,6 +645,15 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
        if self.use_mask_offset:
            print("encoder mask_offset: ", self.mask_offset)
        self.cmp_append_attention(attn_mask=self.attention_mask)
        if self.use_dynamic_quant:
            naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
                self.cache_k_T,
                self.cache_v_T,
                self.batch_size,
                self.block_tables,
                self.seq_len,
            )
        else:
            naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
                self.cache_k,
                self.cache_v,
@@ -613,10 +695,10 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope):
    def setUp(self):
        paddle.disable_static()
        self.name = "TestAppendGroupQueryAttnWithRope"
        self.name = "TestAppendGroupQueryAttnWithNeoXRope"
        self.place = paddle.CUDAPlace(0)
        self.batch_size = 1
        self.q_num_head = 12
        self.q_num_head = 16
        self.kv_num_head = 2
        self.seq_len = 64
        self.max_dec_len = 64
@@ -632,6 +714,33 @@ class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope):
        self.dtype = "float16"
        self.use_qk_norm = False
        self.use_mask_offset = True
        self.use_dynamic_quant = False
        self.init_tensor()


class TestAppendGroupQueryAttnWithRopeDyCfp8(TestAppendGroupQueryAttnWithRope):
    def setUp(self):
        paddle.disable_static()
        self.name = "TestAppendGroupQueryAttnWithRopeDyCfp8"
        self.place = paddle.CUDAPlace(0)
        self.batch_size = 1
        self.q_num_head = 16
        self.kv_num_head = 2
        self.seq_len = 64
        self.max_dec_len = 64
        self.dim_head = 128
        self.q_hid_dim = self.q_num_head * self.dim_head
        self.kv_hid_dim = self.kv_num_head * self.dim_head
        self.blocksize = 64
        self.use_neox_rotary_style = False
        # max_seq_len = self.seq_len + self.max_dec_len
        self.max_seq_len = self.seq_len + self.max_dec_len
        self.softmax_scale = self.dim_head**-0.5
        self.rope_theta = 10000
        self.dtype = "bfloat16"
        self.use_qk_norm = True
        self.use_mask_offset = False
        self.use_dynamic_quant = True
        self.init_tensor()