supports dynamic Cfp8 (#3767)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled

* supports dynamic Cfp8

* add unittest
This commit is contained in:
lzy
2025-09-08 11:41:29 +08:00
committed by GitHub
parent b5e20e3015
commit af49b81ffd
20 changed files with 1417 additions and 225 deletions

View File

@@ -140,8 +140,8 @@ void AppendAttentionKernel(
key_cache,
value_cache,
attn_mask,
cache_k_dequant_scales,
cache_v_dequant_scales,
cache_quant_type_str == "block_wise_fp8" ? cache_k_quant_scales : cache_k_dequant_scales,
cache_quant_type_str == "block_wise_fp8" ? cache_v_quant_scales : cache_v_dequant_scales,
cache_k_zp,
cache_v_zp,
out_linear_shifts,

View File

@@ -32,14 +32,15 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true,
bool is_scale_channel_wise = false,
bool IsFP8=false>
bool IsFP8 = false,
bool IsDynamicC8 = false>
__global__ void multi_query_append_attention_c8_kernel(
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
CacheT *__restrict__ cache_v,
const T *__restrict__ cache_k_scale, // [num_kv_heads]
const T *__restrict__ cache_v_scale, // [num_kv_heads]
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const int *__restrict__ seq_lens,
@@ -91,9 +92,10 @@ __global__ void multi_query_append_attention_c8_kernel(
return;
}
T cache_k_scale_reg[num_frags_y * 4];
T cache_v_scale_reg[num_frags_y * 2];
if (is_scale_channel_wise) {
T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4];
T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2];
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
@@ -114,6 +116,7 @@ __global__ void multi_query_append_attention_c8_kernel(
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
}
}
const uint32_t q_end =
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));
@@ -201,6 +204,13 @@ __global__ void multi_query_append_attention_c8_kernel(
smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)),
v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
T* k_smem_scale = nullptr;
T* v_smem_scale = nullptr;
if constexpr (IsDynamicC8) {
k_smem_scale = reinterpret_cast<T*>(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
v_smem_scale = k_smem_scale + num_frags_z * 16;
}
const uint32_t num_iterations = div_up(
@@ -282,10 +292,22 @@ __global__ void multi_query_append_attention_c8_kernel(
#pragma unroll 1
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
if constexpr (IsDynamicC8) {
produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
k_smem_scale,
cache_k_scale_reg,
block_table_now,
cache_k_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
wait_group<1>();
__syncthreads();
// s = qk
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8>(
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
&qo_smem,
&q_smem_offset_r,
&k_smem,
@@ -318,6 +340,7 @@ __global__ void multi_query_append_attention_c8_kernel(
s_frag, o_frag, m_frag, d_frag);
__syncthreads();
const int ori_kv_idx_base = kv_idx_base;
kv_idx_base += num_frags_z * 16;
produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
NUM_WARPS,
@@ -336,6 +359,18 @@ __global__ void multi_query_append_attention_c8_kernel(
chunk_end,
const_k_offset);
commit_group();
if constexpr (IsDynamicC8) {
produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
v_smem_scale,
cache_v_scale_reg,
block_table_now,
cache_v_scale,
ori_kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
wait_group<1>();
__syncthreads();
@@ -346,7 +381,9 @@ __global__ void multi_query_append_attention_c8_kernel(
BLOCK_SIZE,
T,
CacheT,
is_scale_channel_wise, IsFP8>(
is_scale_channel_wise,
IsFP8,
IsDynamicC8>(
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg);
__syncthreads();
@@ -463,14 +500,15 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true,
bool is_scale_channel_wise=false,
bool IsFP8=false>
bool IsFP8 = false,
bool IsDynamicC8 = false>
__global__ void multi_query_append_attention_c8_warp1_4_kernel(
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
CacheT *__restrict__ cache_v,
const T *__restrict__ cache_k_scale, // [num_kv_heads, head_dim]
const T *__restrict__ cache_v_scale, // [num_kv_heads, head_dim]
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const int *__restrict__ seq_lens,
@@ -522,9 +560,10 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
if (q_len <= 0) {
return;
}
T cache_k_scale_reg[num_frags_y * 4];
T cache_v_scale_reg[num_frags_y * 2];
if (is_scale_channel_wise) {
T cache_k_scale_reg[IsDynamicC8 ? num_frags_z * 2 : num_frags_y * 4];
T cache_v_scale_reg[IsDynamicC8 ? num_frags_z * 4 : num_frags_y * 2];
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
int scale_col_base = threadIdx.x % 4 * 2 + kv_head_idx * HEAD_DIM;
const T *cache_k_scale_cur_head = cache_k_scale + scale_col_base;
for (int i = 0; i < num_frags_y; ++i) {
@@ -545,6 +584,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
cache_k_scale_reg[0] = cache_k_scale[kv_head_idx];
cache_v_scale_reg[0] = cache_v_scale[kv_head_idx];
}
}
const uint32_t q_end =
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));
uint32_t kv_len = seq_lens_kv[batch_id];
@@ -634,6 +674,13 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)),
v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
T* k_smem_scale = nullptr;
T* v_smem_scale = nullptr;
if constexpr (IsDynamicC8) {
k_smem_scale = reinterpret_cast<T*>(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
v_smem_scale = k_smem_scale + NUM_WARP_KV * num_frags_z * 16;
}
const uint32_t num_iterations = div_up(
CAUSAL
@@ -716,11 +763,23 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
commit_group();
#pragma unroll 1
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
if constexpr (IsDynamicC8) {
produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
k_smem_scale,
cache_k_scale_reg,
block_table_now,
cache_k_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
wait_group<1>();
__syncthreads();
// s = qk
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8>(
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
&qo_smem,
&q_smem_offset_r,
&k_smem,
@@ -753,6 +812,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
s_frag, o_frag, m_frag, d_frag);
__syncthreads();
const uint32_t ori_kv_idx_base = kv_idx_base;
kv_idx_base += NUM_WARP_KV * num_frags_z * 16;
produce_k_blockwise_c8<SharedMemFillMode::kNoFill,
NUM_WARPS,
@@ -771,6 +831,18 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
chunk_end,
const_k_offset);
commit_group();
if constexpr (IsDynamicC8) {
produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
v_smem_scale,
cache_v_scale_reg,
block_table_now,
cache_v_scale,
ori_kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
wait_group<1>();
__syncthreads();
@@ -781,7 +853,9 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
BLOCK_SIZE,
T,
CacheT,
is_scale_channel_wise, IsFP8>(
is_scale_channel_wise,
IsFP8,
IsDynamicC8>(
&v_smem, &v_smem_offset_r, s_frag, o_frag, d_frag, cache_v_scale_reg);
__syncthreads();
@@ -895,7 +969,8 @@ template <typename T,
uint32_t NUM_WARP_Q,
typename OutT = T,
bool ENABLE_PREFILL = true,
bool IsFP8=false>
bool IsFP8 = false,
bool IsDynamicC8 = false>
void MultiQueryAppendC8Attention(
const AppendAttnMetaData &meta_data,
const paddle::Tensor &qkv,
@@ -953,7 +1028,8 @@ void MultiQueryAppendC8Attention(
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16;
constexpr uint32_t smem_size =
num_warps * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2;
num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 +
num_frags_z * 16 * sizeof(T) * 2;
auto split_kv_kernel =
multi_query_append_attention_c8_kernel<NV_TYPE,
uint8_t,
@@ -970,7 +1046,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false, IsFP8>;
false,
IsFP8,
IsDynamicC8>;
if (is_scale_channel_wise) {
split_kv_kernel =
multi_query_append_attention_c8_kernel<NV_TYPE,
@@ -988,7 +1066,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true, IsFP8>;
true,
IsFP8,
IsDynamicC8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(split_kv_kernel,
@@ -1022,7 +1102,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false, IsFP8>;
false,
IsFP8,
IsDynamicC8>;
if (is_scale_channel_wise) {
nosplit_kv_kernel =
multi_query_append_attention_c8_kernel<NV_TYPE,
@@ -1040,7 +1122,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true, IsFP8>;
true,
IsFP8,
IsDynamicC8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(nosplit_kv_kernel,
@@ -1218,7 +1302,8 @@ void MultiQueryAppendC8Attention(
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2;
constexpr uint32_t smem_size =
num_frags_x * 16 * HEAD_DIM * sizeof(T) +
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2;
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(uint8_t) * 2 +
NUM_WARP_KV * num_frags_z * 16 * sizeof(T) * 2;
auto split_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
uint8_t,
@@ -1235,7 +1320,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false, IsFP8>;
false,
IsFP8,
IsDynamicC8>;
if (is_scale_channel_wise) {
split_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
@@ -1253,7 +1340,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true, IsFP8>;
true,
IsFP8,
IsDynamicC8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(split_kv_kernel,
@@ -1295,7 +1384,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
false, IsFP8>;
false,
IsFP8,
IsDynamicC8>;
if (is_scale_channel_wise) {
nosplit_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
@@ -1313,7 +1404,9 @@ void MultiQueryAppendC8Attention(
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true, IsFP8>;
true,
IsFP8,
IsDynamicC8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(nosplit_kv_kernel,
@@ -1546,6 +1639,7 @@ void CascadeAppendAttentionC8Kernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out) {
const auto token_num = meta_data.token_nums;
@@ -1554,6 +1648,7 @@ void CascadeAppendAttentionC8Kernel(
const auto num_heads = meta_data.q_num_heads;
const auto group_size = meta_data.q_num_heads / meta_data.kv_num_heads;
const auto head_dim = meta_data.head_dims;
bool is_dynamic_cfp8 = cache_quant_type_str == "block_wise_fp8";
DISPATCH_CAUSAL(
causal,
@@ -1572,6 +1667,7 @@ void CascadeAppendAttentionC8Kernel(
BLOCK_SIZE,
{DISPATCH_BLOCKSHAPE_Q(
block_shape_q, BLOCK_SHAPE_Q, NUM_WARP_Q, {
DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, {
MultiQueryAppendC8Attention<T,
GROUP_SIZE,
HEAD_DIM,
@@ -1580,7 +1676,9 @@ void CascadeAppendAttentionC8Kernel(
BLOCK_SHAPE_Q,
NUM_WARP_Q,
OutT,
ENABLE_PREFILL, IsFP8>(
ENABLE_PREFILL,
IsFP8,
IsDynamicC8>(
meta_data,
qkv,
cache_k,
@@ -1610,5 +1708,5 @@ void CascadeAppendAttentionC8Kernel(
is_decoder,
stream,
out);
})})})})})})
})})})})})})})
}

View File

@@ -384,6 +384,113 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
}
}
template<uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_k_dynamic_scale(
T* k_smem_scale,
T* cache_k_reg,
const int* block_table_now,
const T* cache_k_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end
) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
if (tid < block_size) {
k_smem_scale[tid] = cache_k_scale_now[tid];
}
__syncthreads();
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_k_reg[fz * 2] = k_smem_scale[fz * 16 + row_id];
cache_k_reg[fz * 2 + 1] = k_smem_scale[fz * 16 + row_id + 8];
}
} else {
// 1 warp 32 tokens
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
} else {
k_smem_scale[ty * 32 + tx] = 0;
}
__syncwarp();
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_k_reg[fz * 2] = k_smem_scale[ty * 32 + fz * 16 + row_id];
cache_k_reg[fz * 2 + 1] = k_smem_scale[ty * 32 + fz * 16 + row_id + 8];
}
}
}
template<uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_v_dynamic_scale(
T* v_smem_scale,
T* cache_v_reg,
const int* block_table_now,
const T* cache_v_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end
) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
if (tid < block_size) {
v_smem_scale[tid] = cache_v_scale_now[tid];
}
__syncthreads();
const uint32_t row_id = tx % 4 * 2;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_v_reg[fz * 4] = v_smem_scale[fz * 16 + row_id];
cache_v_reg[fz * 4 + 1] = v_smem_scale[fz * 16 + row_id + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[fz * 16 + row_id + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[fz * 16 + row_id + 9];
}
} else {
// 1 warp 32 tokens
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
} else {
v_smem_scale[ty * 32 + tx] = 0;
}
__syncwarp();
const uint32_t row_id = tx % 4 * 2;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_v_reg[fz * 4] = v_smem_scale[ty * 32 + fz * 16 + row_id];
cache_v_reg[fz * 4 + 1] = v_smem_scale[ty * 32 + fz * 16 + row_id + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[ty * 32 + fz * 16 + row_id + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[ty * 32 + fz * 16 + row_id + 9];
}
}
}
template <SharedMemFillMode fill_mode,
uint32_t num_warps,
uint32_t block_size,
@@ -816,7 +923,8 @@ template <uint32_t num_frags_x,
typename T,
typename CacheT,
bool is_scale_channel_wise = false,
bool IsFP8=false>
bool IsFP8 = false,
bool IsDynamicC8 = false>
__device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
uint32_t* q_smem_offset_r,
smem_t* k_smem,
@@ -860,6 +968,7 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
const int scale_col = (ky * 2 + fy) * 4;
b_frag_dq_T[0] *= cache_k_scale[scale_col];
@@ -876,6 +985,12 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
b_frag_dq_T[b_i] *= cache_k_scale[0];
}
}
} else {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_k_scale[fz * 2 + b_i / 4];
}
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
if (ky == 0 && fy == 0) {
@@ -1093,7 +1208,9 @@ template <uint32_t num_frags_x,
uint32_t block_size,
typename T,
typename CacheT,
bool is_scale_channel_wise = false, bool IsFP8=false>
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
__device__ __forceinline__ void compute_sfm_v_c8(
smem_t* v_smem,
uint32_t* v_smem_offset_r,
@@ -1135,6 +1252,7 @@ __device__ __forceinline__ void compute_sfm_v_c8(
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
@@ -1146,6 +1264,17 @@ __device__ __forceinline__ void compute_sfm_v_c8(
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
}
} else {
const int scale_col = (kz * 2 + fz) * 4;
b_frag_dq_T[0] *= cache_v_scale[scale_col];
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_v_scale[scale_col];
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
mma_sync_m16n16k16_row_col_f16f16f32<T>(
@@ -1171,7 +1300,9 @@ template <uint32_t num_frags_x,
uint32_t block_size,
typename T,
typename CacheT,
bool is_scale_channel_wise = false, bool IsFP8=false>
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
__device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
smem_t* v_smem,
uint32_t* v_smem_offset_r,
@@ -1215,6 +1346,7 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
@@ -1226,6 +1358,17 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
}
} else {
const int scale_col = (kz * 2 + fz) * 4;
b_frag_dq_T[0] *= cache_v_scale[scale_col];
b_frag_dq_T[1] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[2] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[3] *= cache_v_scale[scale_col + 3];
b_frag_dq_T[4] *= cache_v_scale[scale_col];
b_frag_dq_T[5] *= cache_v_scale[scale_col + 1];
b_frag_dq_T[6] *= cache_v_scale[scale_col + 2];
b_frag_dq_T[7] *= cache_v_scale[scale_col + 3];
}
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m: num_frags_x * 16
mma_sync_m16n16k16_row_col_f16f16f32<T>(

View File

@@ -103,6 +103,7 @@ void CascadeAppendAttentionC8Kernel(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -264,9 +265,10 @@ void CascadeAppendAttentionKernel(
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out);
} else if (cache_quant_type_str == "cache_fp8") {
} else if (cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
CascadeAppendAttentionC8Kernel<T, OutT, true>(meta_data,
qkv,
cache_k,
@@ -299,6 +301,7 @@ void CascadeAppendAttentionKernel(
causal,
is_decoder,
enable_prefill,
cache_quant_type_str,
stream,
out);
} else if (cache_quant_type_str == "cache_int4_zp") {

View File

@@ -674,6 +674,294 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
}
}
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128, bool is_scale_channel_wise=false, bool IsFP8=true>
__global__ void append_decode_cache_int8_rope_qk_norm_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads,
// block_size, head_size // 2]
uint8_t* __restrict__ value_cache, // [num_blocks, kv_num_heads,
// block_size, head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ batch_id_per_token, // [num_tokens]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb,
const float* __restrict__ sin_emb,
T* __restrict__ cache_k_scale,
T* __restrict__ cache_v_scale,
const float* q_norm_weight,
const float* k_norm_weight,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int block_size,
const float max_bound,
const float min_bound,
const int kv_num_heads,
const bool rope_3d,
const float rms_norm_eps) {
static_assert(HeadDim == 128, "just support HeadDim be 128 now!");
static_assert(VecSize == 4, "just support VecSize be 4 now, 32 * 4!");
constexpr int NUM_WARPS = 4;
const int tid = threadIdx.x;
const int wid = tid / 32;
const int lane_id = tid % 32;
const int bid = blockIdx.x, head_idx = blockIdx.y * NUM_WARPS + wid;
int q_head_idx, k_head_idx, v_idx;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * HeadDim;
constexpr int half_head_size = HeadDim / 2;
const int start_token_idx = cu_seqlens_q[bid];
if (seq_lens_encoder[bid] > 0) return;
const int write_seq_id = seq_lens[bid];
if (write_seq_id == 0) return;
const int* block_table_now = nullptr;
block_table_now = block_tables + bid * max_blocks_per_seq;
const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
const int block_offset = write_seq_id % block_size;
int cache_offset;
if (head_idx < num_heads) {
cache_offset = 0;
} else if (head_idx < num_heads + 2 * kv_num_heads) {
cache_offset = block_idx * kv_num_heads * block_size + (head_idx - num_heads) % kv_num_heads * block_size + block_offset;
}
T *cache_k_scale_now = cache_k_scale + cache_offset;
T *cache_v_scale_now = cache_v_scale + cache_offset;
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
if (head_idx < num_heads) {
// q
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadOutScaleT = AlignedVector<float, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, HalfVecSize>;
LoadT src_vec;
LoadBiasT out_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
T* qkv_out_now = qkv_out + start_token_idx * hidden_size;
#pragma unroll
for (uint32_t head_bias = lane_id * VecSize; head_bias < HeadDim;
head_bias += 32 * VecSize) {
const int bias_idx = head_idx * HeadDim + head_bias;
Load<T, VecSize>(&qkv_now[bias_idx], &src_vec);
// q rope
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
const uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, HalfVecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, HalfVecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < HalfVecSize; i++) {
// dequant + add_bias + rope
float input_left = static_cast<float>(src_vec[2 * i]);
float input_right = static_cast<float>(src_vec[2 * i + 1]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
out_vec[2 * i] =
static_cast<T>(tmp1);
out_vec[2 * i + 1] =
static_cast<T>(tmp2);
}
// qk norm
if (q_norm_weight) {
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / HeadDim, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
LoadOutScaleT q_norm_vec;
Load<float, VecSize>(&q_norm_weight[lane_id * VecSize], &q_norm_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * q_norm_vec[i]);
}
}
Store<T, VecSize>(out_vec, &qkv_out_now[bias_idx]);
}
} else if (head_idx < num_heads + 2 * kv_num_heads) {
// k
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
const uint32_t kv_head_idx = (head_idx - num_heads) % kv_num_heads;
if (block_offset == 0) {
// pad zero for this kv_head_idx for this block
LoadPadKVT pad_cache_vec;
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
if (head_idx < num_heads + kv_num_heads) {
constexpr int num_vecs_per_head_dim = HeadDim / KV_VEC_SIZE;
constexpr int num_token_each_time = 32 / num_vecs_per_head_dim;
const uint32_t tgt_idx =
(block_idx * kv_num_heads + kv_head_idx) * block_size * HeadDim +
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
for (int block_i = lane_id / num_vecs_per_head_dim;
block_i < block_size;
block_i += num_token_each_time) {
Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
&key_cache[tgt_idx + block_i * HeadDim]);
}
} else {
const int num_vecs_per_head_dim = block_size / KV_VEC_SIZE;
const int num_token_each_time = 32 / num_vecs_per_head_dim;
const uint32_t tgt_idx =
(block_idx * kv_num_heads + kv_head_idx) * HeadDim * block_size +
lane_id % num_vecs_per_head_dim * KV_VEC_SIZE;
for (int block_i = lane_id / num_vecs_per_head_dim; block_i < HeadDim;
block_i += num_token_each_time) {
Store<uint8_t, KV_VEC_SIZE>(
pad_cache_vec, &value_cache[tgt_idx + block_i * block_size]);
}
}
__syncwarp();
}
constexpr int K_VEC_SIZE = 4;
constexpr int HALF_K_VEC_SIZE = 2;
using LoadKVResT = AlignedVector<uint8_t, K_VEC_SIZE>;
using LoadKVT = AlignedVector<uint8_t, HALF_K_VEC_SIZE>;
using LoadT = AlignedVector<T, HALF_K_VEC_SIZE>;
using LoadBiasT = AlignedVector<T, HALF_K_VEC_SIZE>;
using LoadOutScaleT = AlignedVector<float, HALF_K_VEC_SIZE>;
using LoadEmbT = AlignedVector<float, 1>;
LoadKVResT cache_vec;
LoadT src_vec1, src_vec2;
LoadBiasT out_vec1, out_vec2;
LoadEmbT cos_emb_vec1, cos_emb_vec2;
LoadEmbT sin_emb_vec1, sin_emb_vec2;
const T* qkv_now = quant_qkv + start_token_idx * hidden_size;
const int head_bias = lane_id / 4 * 16 + lane_id % 4 * 2;
const int bias_idx = head_idx * HeadDim + head_bias;
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx], &src_vec1);
Load<T, HALF_K_VEC_SIZE>(&qkv_now[bias_idx + 8], &src_vec2);
T scale = T(1.0f);
const int k_head_idx = head_idx - num_heads;
const int v_head_idx = head_idx - num_heads - kv_num_heads;
if (head_idx < num_heads + kv_num_heads) {
const uint32_t emb_idx = write_seq_id * half_head_size + head_bias / 2;
const uint32_t new_emb_idx = rope_3d ? emb_idx + bid * max_seq_len * HeadDim : emb_idx;
Load<float, 1>(&cos_emb[new_emb_idx], &cos_emb_vec1);
Load<float, 1>(&cos_emb[new_emb_idx + 4], &cos_emb_vec2);
Load<float, 1>(&sin_emb[new_emb_idx], &sin_emb_vec1);
Load<float, 1>(&sin_emb[new_emb_idx + 4], &sin_emb_vec2);
}
float input_left = static_cast<float>(src_vec1[0]);
float input_right = static_cast<float>(src_vec1[1]);
if (head_idx < num_heads + kv_num_heads) {
float cos_tmp = cos_emb_vec1[0];
float sin_tmp = sin_emb_vec1[0];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
out_vec1[0] =
static_cast<T>(tmp1);
out_vec1[1] =
static_cast<T>(tmp2);
} else {
out_vec1[0] = src_vec1[0];
out_vec1[1] = src_vec1[1];
}
// rope
input_left = static_cast<float>(src_vec2[0]);
input_right = static_cast<float>(src_vec2[1]);
if (head_idx < num_heads + kv_num_heads) {
float cos_tmp = cos_emb_vec2[0];
float sin_tmp = sin_emb_vec2[0];
float tmp1 = input_left * cos_tmp - input_right * sin_tmp;
float tmp2 = input_right * cos_tmp + input_left * sin_tmp;
thread_m2 += tmp1 * tmp1 + tmp2 * tmp2;
out_vec2[0] = static_cast<T>(tmp1);
out_vec2[1] = static_cast<T>(tmp2);
} else {
out_vec2[0] = src_vec2[0];
out_vec2[1] = src_vec2[1];
}
if (k_norm_weight) {
if (head_idx < num_heads + kv_num_heads) {
LoadOutScaleT k_norm_vec1, k_norm_vec2;
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias], &k_norm_vec1);
Load<float, HALF_K_VEC_SIZE>(&k_norm_weight[head_bias + 8], &k_norm_vec2);
// qk norm
WelfordWarpAllReduce<float, 32>(thread_m2, &warp_m2);
float row_variance =
max(warp_m2 / HeadDim, 0.0f);
float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
out_vec1[i] = static_cast<T>(static_cast<float>(out_vec1[i]) * row_inv_var * k_norm_vec1[i]);
out_vec2[i] = static_cast<T>(static_cast<float>(out_vec2[i]) * row_inv_var * k_norm_vec2[i]);
}
}
}
// reduce max, 1 head per warp
T local_max = -INFINITY;
#pragma unroll
for (int i = 0; i < HALF_K_VEC_SIZE; i++) {
local_max = __hmax(local_max, __habs(out_vec1[i]));
local_max = __hmax(local_max, __habs(out_vec2[i]));
}
#pragma unroll
for (int m_offset = 16; m_offset > 1; m_offset /= 2) {
local_max = __hmax(local_max, __shfl_xor_sync(0xffffffff, local_max, m_offset));
}
scale = __hdiv(448, local_max);
if (lane_id == 0) {
if (head_idx < num_heads + kv_num_heads) {
cache_k_scale_now[0] = __hdiv(1, scale);
} else {
cache_v_scale_now[0] = __hdiv(1, scale);
}
}
#pragma unroll
for (uint32_t i = 0; i < HALF_K_VEC_SIZE; i++) {
cache_vec[i] = QuantToC8<T,true, IsFP8, RoundType>(scale, out_vec1[i], max_bound, min_bound);
cache_vec[i + HALF_K_VEC_SIZE] = QuantToC8<T,true, IsFP8, RoundType>(scale, out_vec2[i], max_bound, min_bound);
}
if (head_idx < num_heads + kv_num_heads) {
const int start_block_16 =
block_offset / 16 * 16 + block_offset % 8 + lane_id / 4 % 2 * 8;
const uint32_t tgt_cache_idx =
block_idx * kv_num_heads * block_size * HeadDim +
kv_head_idx * block_size * HeadDim + start_block_16 * HeadDim +
lane_id / 4 / 2 * 32 + (block_offset % 16) / 8 * 16 + lane_id % 4 * 4;
Store<uint8_t, K_VEC_SIZE>(cache_vec, &key_cache[tgt_cache_idx]);
} else {
const uint32_t base_tgt_cache_idx =
block_idx * kv_num_heads * HeadDim * block_size +
kv_head_idx * HeadDim * block_size +
(lane_id / 4 * 16 + lane_id % 4 * 2) * block_size +
block_offset / 16 % 2 * 8 * block_size + block_offset / 16 / 2 * 32;
const uint32_t tgt_cache_idx1 = base_tgt_cache_idx +
block_offset % 8 / 2 * 4 // per 4
+ block_offset % 16 / 8 * 2 // per 2
+ block_offset % 2; // per 1
const uint32_t tgt_cache_idx2 = tgt_cache_idx1 + block_size;
const uint32_t tgt_cache_idx3 = tgt_cache_idx1 + 16;
const uint32_t tgt_cache_idx4 = tgt_cache_idx3 + block_size;
value_cache[tgt_cache_idx1] = cache_vec[0];
value_cache[tgt_cache_idx2] = cache_vec[1];
value_cache[tgt_cache_idx3] = cache_vec[2];
value_cache[tgt_cache_idx4] = cache_vec[3];
}
}
}
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128, bool is_scale_channel_wise=false, bool IsFP8=false>
__global__ void append_decode_cache_int8_rope_kernel(
const T* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,

View File

@@ -553,9 +553,40 @@ void DecoderWriteCacheWithRoPEKernel(
q_norm_weight ? q_norm_weight.get().data<float>() : nullptr,
k_norm_weight ? k_norm_weight.get().data<float>() : nullptr,
rms_norm_eps);
} else if (cache_quant_type_str == "block_wise_fp8") {
constexpr int num_warps = 4;
const int all_warps =
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
dim3 grids(bsz, all_warps / num_warps);
append_decode_cache_int8_rope_qk_norm_kernel<DataType_, 4, 0, 128, false, true>
<<<grids, num_warps * 32, 0, stream>>>(
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
const_cast<DataType_*>(reinterpret_cast<const DataType_*>((cache_v_scale.get().data<T>()))),
q_norm_weight.get().data<float>(),
k_norm_weight.get().data<float>(),
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d,
rms_norm_eps);
} else {
PD_THROW(
"append_decode_cache_rope_qk_norm not support cachekv quant yet");
"append_decode_cache_rope_qk_norm just supports cache_quant_type none/block_wise_fp8");
}
} else {
if (cache_quant_type_str == "none") {
@@ -686,6 +717,37 @@ void DecoderWriteCacheWithRoPEKernel(
stream,
use_neox_rotary_style,
rope_3d);
} else if (cache_quant_type_str == "block_wise_fp8") {
constexpr int num_warps = 4;
const int all_warps =
((num_heads + 2 * kv_num_heads) + num_warps - 1) / num_warps * num_warps;
dim3 grids(bsz, all_warps / num_warps);
append_decode_cache_int8_rope_qk_norm_kernel<DataType_, 4, 0, 128, false, true>
<<<grids, num_warps * 32, 0, stream>>>(
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(cache_k_scale.get().data<T>())),
const_cast<DataType_*>(reinterpret_cast<const DataType_*>((cache_v_scale.get().data<T>()))),
nullptr,
nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d,
rms_norm_eps);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_decode_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),

View File

@@ -1232,6 +1232,411 @@ __global__ void append_write_cache_kv_c8_qkv(
}
}
template <typename T,
uint32_t num_frags_y,
uint32_t num_frags_z,
uint32_t HEAD_DIM,
uint32_t BLOCK_SIZE,
uint32_t NUM_WARPS,
bool is_need_kv_quant,
bool IsFP8 = true>
__global__ void append_write_cache_kv_c8_qkv_dynamic(
uint8_t *__restrict__ cache_k,
uint8_t *__restrict__ cache_v,
const T *__restrict__ qkv_input,
T *__restrict__ cache_k_scales, // [block_num, num_heads, block_size]
T *__restrict__ cache_v_scales, // [block_num, num_heads, block_size]
const int *__restrict__ batch_ids,
const int *__restrict__ tile_ids,
const int *__restrict__ seq_lens_this_time,
const int *__restrict__ seq_lens_decoder,
const int *__restrict__ batch_id_per_token,
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_tables,
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int kv_num_heads) {
constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b<T>();
constexpr uint32_t pad_len = BLOCK_SIZE;
const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z;
const T cache_k_scale = cache_k_scales[kv_head_idx];
const T cache_v_scale = cache_v_scales[kv_head_idx];
const uint32_t tid = threadIdx.x, wid = threadIdx.y;
const uint32_t batch_id = batch_ids[btid];
const uint32_t tile_id = tile_ids[btid];
const uint32_t seq_len_this_time = seq_lens_this_time[batch_id];
if (seq_len_this_time <= 0) {
return;
}
const int *block_table_now = nullptr;
block_table_now = block_tables + batch_id * max_blocks_per_seq;
const uint32_t num_rows_per_block =
NUM_WARPS * num_frags_z * 16; // BLOCK_SIZE
const uint32_t start_len = seq_lens_decoder[batch_id];
const uint32_t bf_pad_len = start_len % pad_len;
const uint32_t start_len_pad = start_len - bf_pad_len;
const uint32_t end_len = start_len + seq_len_this_time;
const uint32_t tile_start = start_len_pad + tile_id * num_rows_per_block;
int block_id = __ldg(&block_table_now[tile_start / BLOCK_SIZE]);
uint32_t chunk_start = tile_start + wid * num_frags_z * 16 + tid / 8;
const uint32_t start_token_idx = cu_seqlens_q[batch_id];
const uint32_t kv_batch_stride = (num_heads + 2 * kv_num_heads) * HEAD_DIM;
const uint32_t kv_h_stride = HEAD_DIM;
__shared__ T k_smem_ori[num_rows_per_block * HEAD_DIM];
__shared__ T v_smem_ori[num_rows_per_block * HEAD_DIM];
__shared__ T v_scale_smem[BLOCK_SIZE];
if (tile_start >= start_len) {
constexpr int KV_VEC_SIZE = 16 / sizeof(uint8_t); // 16
using LoadPadKVT = AlignedVector<uint8_t, KV_VEC_SIZE>;
// pad zero for this kv_head_idx for this block
LoadPadKVT pad_cache_vec;
*(reinterpret_cast<uint4*>(pad_cache_vec.val)) = make_uint4(0, 0, 0, 0);
// reset k
constexpr int num_vecs_per_head_k = HEAD_DIM / KV_VEC_SIZE;
constexpr int num_token_each_time_k = 32 / num_vecs_per_head_k;
uint32_t tgt_idx =
(block_id * kv_num_heads + kv_head_idx) * BLOCK_SIZE * HEAD_DIM +
tid % num_vecs_per_head_k * KV_VEC_SIZE;
for (int block_i = tid / num_vecs_per_head_k;
block_i < BLOCK_SIZE;
block_i += num_token_each_time_k) {
Store<uint8_t, KV_VEC_SIZE>(pad_cache_vec,
&cache_k[tgt_idx + block_i * HEAD_DIM]);
}
// reset v
const int num_vecs_per_head_v = BLOCK_SIZE / KV_VEC_SIZE;
const int num_token_each_time_v = 32 / num_vecs_per_head_v;
tgt_idx =
(block_id * kv_num_heads + kv_head_idx) * HEAD_DIM * BLOCK_SIZE +
tid % num_vecs_per_head_v * KV_VEC_SIZE;
for (int block_i = tid / num_vecs_per_head_v; block_i < HEAD_DIM;
block_i += num_token_each_time_v) {
Store<uint8_t, KV_VEC_SIZE>(
pad_cache_vec, &cache_v[tgt_idx + block_i * BLOCK_SIZE]);
}
}
smem_t k_smem(k_smem_ori);
smem_t v_smem(v_smem_ori);
uint32_t kv_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
wid * num_frags_z * 16 + tid / 8, tid % 8); // 4 * 8 per warp
/*
0 | 1
2 | 3
*/
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
wid * num_frags_z * 16 + 8 * (tid / 16) + tid % 8, (tid % 16) / 8);
constexpr uint32_t num_frags_v = num_frags_y / NUM_WARPS;
/*
0 | 2
1 | 3
*/
uint32_t v_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
tid % 16, wid * num_frags_v * 2 + tid / 16);
// load kv gmem to smem
const uint32_t real_start_token_idx = start_token_idx - bf_pad_len +
tile_id * num_rows_per_block +
wid * num_frags_z * 16 + tid / 8;
uint32_t k_read_idx = real_start_token_idx * kv_batch_stride +
(num_heads + kv_head_idx) * kv_h_stride +
tid % 8 * num_elems_per_128b<T>();
uint32_t v_read_idx = real_start_token_idx * kv_batch_stride +
(num_heads + kv_num_heads + kv_head_idx) * kv_h_stride +
tid % 8 * num_elems_per_128b<T>();
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
#pragma unroll
for (uint32_t j = 0; j < 4; ++j) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y / 4;
++fy) { // (num_frags_y * 16) / (8 * num_elems_per_128b<T>())
if (chunk_start >= start_len && chunk_start < end_len) {
k_smem.load_128b_async<SharedMemFillMode::kNoFill>(
kv_smem_offset_w, qkv_input + k_read_idx, chunk_start < end_len);
v_smem.load_128b_async<SharedMemFillMode::kNoFill>(
kv_smem_offset_w, qkv_input + v_read_idx, chunk_start < end_len);
}
kv_smem_offset_w =
k_smem.advance_offset_by_column<8>(kv_smem_offset_w, fy);
k_read_idx += 8 * num_elems_per_128b<T>();
v_read_idx += 8 * num_elems_per_128b<T>();
}
kv_smem_offset_w =
k_smem.advance_offset_by_row<4, num_vecs_per_head>(kv_smem_offset_w) -
2 * num_frags_y;
chunk_start += 4;
k_read_idx +=
4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b<T>();
v_read_idx +=
4 * kv_batch_stride - 2 * num_frags_y * num_elems_per_128b<T>();
}
}
commit_group();
wait_group<0>();
__syncthreads();
// reduce scale
// 16 rows per warp
uint32_t kv_reduce_frag[4];
T *kv_reduce_frag_T = reinterpret_cast<T*>(kv_reduce_frag);
T k_local_max_value[num_frags_z * 2];
T v_local_max_value[num_frags_z * 2];
#pragma unroll
for (int i = 0; i < num_frags_z * 2; i++) {
k_local_max_value[i] = -INFINITY;
}
#pragma unroll
for (int i = 0; i < num_frags_z * 2; i++) {
v_local_max_value[i] = -INFINITY;
}
const int num_kv_heads = gridDim.z;
const int scale_offset = block_id * num_kv_heads * BLOCK_SIZE + kv_head_idx * BLOCK_SIZE;
T *cache_k_scale_now = cache_k_scales + scale_offset;
T *cache_v_scale_now = cache_v_scales + scale_offset;
// k scale
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
// reduce per thread, 4 threads each row
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag);
#pragma unroll
for (int i = 0; i < 4; i++) {
k_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), k_local_max_value[fz * 2]);
}
#pragma unroll
for (int i = 0; i < 4; i++) {
k_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), k_local_max_value[fz * 2 + 1]);
}
k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
}
// reduce per row
for (int i = 0; i < 2; i++) {
T local_max_value = __habs(k_local_max_value[fz * 2 + i]);
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2));
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1));
// used for quant
k_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value);
}
// store
if (tid % 4 == 0) {
const int offset_now = wid * num_frags_z * 16 + tid / 4;
// used for dequant
if (tile_start + offset_now >= start_len) {
if (tile_start + offset_now < end_len) {
cache_k_scale_now[offset_now] = __hdiv(1, k_local_max_value[fz * 2]);
} else {
cache_k_scale_now[offset_now] = 0;
}
}
if (tile_start + offset_now + 8 >= start_len) {
if (tile_start + offset_now + 8 < end_len) {
cache_k_scale_now[offset_now + 8] = __hdiv(1, k_local_max_value[fz * 2 + 1]);
} else {
cache_k_scale_now[offset_now + 8] = 0;
}
}
}
__syncthreads();
k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1
}
// v scale
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
// reduce per thread, 4 threads each row
v_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_reduce_frag);
#pragma unroll
for (int i = 0; i < 4; i++) {
v_local_max_value[fz * 2] = __hmax(__habs(kv_reduce_frag_T[i]), v_local_max_value[fz * 2]);
}
#pragma unroll
for (int i = 0; i < 4; i++) {
v_local_max_value[fz * 2 + 1] = __hmax(__habs(kv_reduce_frag_T[i + 4]), v_local_max_value[fz * 2 + 1]);
}
k_smem_offset_r = v_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
}
// reduce per row
for (int i = 0; i < 2; i++) {
T local_max_value = __habs(v_local_max_value[fz * 2 + i]);
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 2));
local_max_value = __hmax(local_max_value, __shfl_xor_sync(0xffffffff, local_max_value, 1));
v_local_max_value[fz * 2 + i] = __hdiv(448, local_max_value);
}
// store
if (tid % 4 == 0) {
const int offset_now = wid * num_frags_z * 16 + tid / 4;
// used for dequant
if (tile_start + offset_now >= start_len) {
if (tile_start + offset_now < end_len) {
cache_v_scale_now[offset_now] = __hdiv(1, v_local_max_value[fz * 2]);
v_scale_smem[offset_now] = v_local_max_value[fz * 2];
} else {
cache_v_scale_now[offset_now] = 0;
v_scale_smem[offset_now] = 0;
}
}
if (tile_start + offset_now + 8 >= start_len) {
if (tile_start + offset_now + 8 < end_len) {
cache_v_scale_now[offset_now + 8] = __hdiv(1, v_local_max_value[fz * 2 + 1]);
v_scale_smem[offset_now + 8] = v_local_max_value[fz * 2 + 1];
} else {
cache_v_scale_now[offset_now + 8] = 0;
v_scale_smem[offset_now + 8] = 0;
}
}
}
__syncthreads();
k_smem_offset_r -= 2 * num_frags_y; // num_frags_z = 1
}
__syncthreads();
// mask, quant, store
using LoadKVT = AlignedVector<uint8_t, 4>;
LoadKVT cache_vec1;
LoadKVT cache_vec2;
uint32_t chunk_start_k = tile_start + wid * num_frags_z * 16 + tid / 4;
uint32_t kv_frag[4];
const uint32_t write_n_stride = kv_num_heads * BLOCK_SIZE * HEAD_DIM;
const uint32_t write_h_stride = BLOCK_SIZE * HEAD_DIM;
const uint32_t write_b_stride = HEAD_DIM;
const uint32_t write_d_stride = BLOCK_SIZE;
uint32_t k_write_idx = block_id * write_n_stride +
kv_head_idx * write_h_stride +
(wid * num_frags_z * 16 + tid / 4) * write_b_stride +
tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
uint32_t k_write_idx_now_z = k_write_idx + fz * 16 * write_b_stride;
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t k_write_idx_now = k_write_idx_now_z +
fy % 2 * 8 * write_b_stride +
fy / 2 * 32; // + fy % 2 * 16;
// load
k_smem.ldmatrix_m8n8x4(k_smem_offset_r, kv_frag);
// quant
T *k_frag_T = reinterpret_cast<T *>(kv_frag);
if (bf_pad_len != 0) {
Load<uint8_t, 4>(cache_k + k_write_idx_now, &cache_vec1);
Load<uint8_t, 4>(cache_k + k_write_idx_now + 16, &cache_vec2);
}
#pragma unroll
for (uint32_t v_id = 0; v_id < 8; ++v_id) {
uint8_t uint_quant_value;
if (chunk_start_k + (v_id / 4) * 8 >= start_len &&
chunk_start_k + (v_id / 4) * 8 < end_len) {
uint_quant_value = QuantToC8<T, is_need_kv_quant, IsFP8>(k_local_max_value[fz * 2 + v_id / 4], k_frag_T[v_id], 127.0f, -127.0f);
} else {
uint_quant_value = 0;
}
if (bf_pad_len != 0) {
if (v_id < 4) {
cache_vec1[v_id] |= uint_quant_value;
} else {
cache_vec2[v_id % 4] |= uint_quant_value;
}
} else {
if (v_id < 4) {
cache_vec1[v_id] = uint_quant_value;
} else {
cache_vec2[v_id - 4] = uint_quant_value;
}
}
}
// store
Store<uint8_t, 4>(cache_vec1, cache_k + k_write_idx_now);
Store<uint8_t, 4>(cache_vec2, cache_k + k_write_idx_now + 16);
k_smem_offset_r = k_smem.advance_offset_by_column<2>(k_smem_offset_r, fy);
}
k_smem_offset_r =
k_smem.advance_offset_by_row<16, num_vecs_per_head>(k_smem_offset_r) -
2 * num_frags_y;
chunk_start_k += 16;
}
uint32_t chunk_start_v = tile_start + tid % 4 * 2;
uint32_t v_write_idx = block_id * write_n_stride +
kv_head_idx * write_h_stride +
(wid * num_frags_v * 16 + tid / 4) * write_d_stride +
tid % 4 * 4; // 4 * int8 = 8 * int4 = 32bit
const uint32_t num_frags_z_v = num_frags_z * NUM_WARPS;
T v_scales[num_frags_z_v * 4];
for (int v_i = 0; v_i < num_frags_z_v; v_i++) {
const int offset = v_i * 16;
const int t_offset = tid % 4 * 2;
v_scales[v_i * 4] = v_scale_smem[offset + t_offset];
v_scales[v_i * 4 + 1] = v_scale_smem[offset + t_offset + 1];
v_scales[v_i * 4 + 2] = v_scale_smem[offset + t_offset + 8];
v_scales[v_i * 4 + 3] = v_scale_smem[offset + t_offset + 9];
}
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_v; ++fy) {
uint32_t v_write_idx_now_v = v_write_idx + fy * 16 * write_d_stride;
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z_v; ++fz) {
uint32_t v_write_idx_now = v_write_idx_now_v +
fz % 2 * 8 * write_d_stride +
fz / 2 * 32; // + fz % 2 * 16;
// load
v_smem.ldmatrix_m8n8x4_trans(v_smem_offset_r, kv_frag);
// quant
T *v_frag_T = reinterpret_cast<T *>(kv_frag);
if (bf_pad_len != 0) {
Load<uint8_t, 4>(cache_v + v_write_idx_now, &cache_vec1);
Load<uint8_t, 4>(cache_v + v_write_idx_now + 16, &cache_vec2);
}
#pragma unroll
for (uint32_t v_id = 0; v_id < 8; ++v_id) {
uint8_t uint_quant_value;
if (chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 >= start_len &&
chunk_start_v + v_id % 2 + (v_id % 4) / 2 * 8 < end_len) {
uint_quant_value = QuantToC8<T, is_need_kv_quant, IsFP8>(v_scales[fz * 4 + v_id % 4], v_frag_T[v_id], 127.0f, -127.0f);
// store now
} else {
uint_quant_value = 0;
}
if (bf_pad_len != 0) {
if (v_id < 4) {
cache_vec1[v_id] |= uint_quant_value;
} else {
cache_vec2[v_id % 4] |= uint_quant_value;
}
} else {
if (v_id < 4) {
cache_vec1[v_id] = uint_quant_value;
} else {
cache_vec2[v_id % 4] = uint_quant_value;
}
}
}
// store
Store<uint8_t, 4>(cache_vec1, cache_v + v_write_idx_now);
Store<uint8_t, 4>(cache_vec2, cache_v + v_write_idx_now + 16);
chunk_start_v += 16;
v_smem_offset_r =
k_smem.advance_offset_by_row<16, num_vecs_per_head>(v_smem_offset_r);
}
v_smem_offset_r = k_smem.advance_offset_by_column<2>(
v_smem_offset_r, wid * num_frags_v + fy) -
16 * num_frags_z_v * num_vecs_per_head;
chunk_start_v -= 16 * num_frags_z_v;
}
}
// Write Cache KV in Append
template <typename T,
uint32_t num_frags_y,
@@ -2006,10 +2411,11 @@ void CascadeAppendWriteCacheKVC8QKV(
int num_blocks_x_cpu,
int max_seq_len,
bool is_scale_channel_wise,
const bool is_fp8,
const std::string& cache_quant_type,
cudaStream_t &stream,
paddle::Tensor *cache_k_out,
paddle::Tensor *cache_v_out) {
using NV_TYPE = typename cascade_attn_type_traits<T>::type;
auto max_blocks_per_seq = meta_data.max_blocks_per_seq;
auto num_tokens = meta_data.token_nums;
auto num_heads = meta_data.q_num_heads;
@@ -2027,6 +2433,7 @@ void CascadeAppendWriteCacheKVC8QKV(
dim3 blocks(32, num_warps);
const uint32_t smem_size = (BLOCK_SIZE * HEAD_DIM) * sizeof(T) * 2;
if (cache_quant_type != "block_wise_fp8") {
auto kernel_fn = append_write_cache_kv_c8_qkv<T,
num_frags_y,
num_frags_z,
@@ -2034,7 +2441,7 @@ void CascadeAppendWriteCacheKVC8QKV(
BLOCK_SIZE,
num_warps,
true, false>;
if (is_fp8) {
if (cache_quant_type == "cache_fp8") {
kernel_fn = append_write_cache_kv_c8_qkv<T,
num_frags_y,
num_frags_z,
@@ -2070,6 +2477,33 @@ void CascadeAppendWriteCacheKVC8QKV(
max_blocks_per_seq,
num_heads,
kv_num_heads);
} else {
auto kernel_fn = append_write_cache_kv_c8_qkv_dynamic<NV_TYPE,
num_frags_y,
num_frags_z,
HEAD_DIM,
BLOCK_SIZE,
num_warps,
true, true>;
cudaFuncSetAttribute(
kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
kernel_fn<<<grids, blocks, 0, stream>>>(cache_k_out->data<uint8_t>(),
cache_v_out->data<uint8_t>(),
reinterpret_cast<const NV_TYPE*>(qkv.data<T>()),
const_cast<NV_TYPE*>(reinterpret_cast<const NV_TYPE*>(cache_k_scale.data<T>())),
const_cast<NV_TYPE*>(reinterpret_cast<const NV_TYPE*>(cache_v_scale.data<T>())),
batch_ids.data<int>(),
tile_ids_per_batch.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_decoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
block_table.data<int>(),
max_seq_len,
max_blocks_per_seq,
num_heads,
kv_num_heads);
}
}
template <typename T, uint32_t HEAD_DIM, uint32_t BLOCK_SIZE>

View File

@@ -167,7 +167,7 @@ void EncoderWriteCacheWithRopeKernel(
stream,
key_cache_out,
value_cache_out);
} else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8") {
} else if (cache_quant_type_str == "cache_int8" or cache_quant_type_str == "cache_fp8" or cache_quant_type_str == "block_wise_fp8") {
DISPATCH_HEAD_DIM(
head_dim, HEAD_DIM, {DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, {
CascadeAppendWriteCacheKVC8QKV<T, HEAD_DIM, BLOCK_SIZE>(
@@ -187,7 +187,7 @@ void EncoderWriteCacheWithRopeKernel(
num_blocks,
max_seq_len,
is_scale_channel_wise,
cache_quant_type_str == "cache_fp8",
cache_quant_type_str,
stream,
key_cache_out,
value_cache_out);

View File

@@ -1000,7 +1000,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
stream,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));
} else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8") {
} else if (cache_quant_type == "cache_int8" || cache_quant_type == "cache_fp8" || cache_quant_type == "block_wise_fp8") {
CascadeAppendWriteCacheKVC8QKV<data_t, 128, 64>(
meta_data,
*const_cast<paddle::Tensor*>(&key_cache),
@@ -1018,7 +1018,7 @@ std::vector<paddle::Tensor> GQARopeWriteCacheKernel(
kv_num_blocks_data,
max_seq_len,
false, // is_scale_channel_wise
cache_quant_type == "cache_fp8", // is_fp8
cache_quant_type,
stream,
const_cast<paddle::Tensor*>(&key_cache),
const_cast<paddle::Tensor*>(&value_cache));

View File

@@ -56,6 +56,7 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, false>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -103,5 +104,6 @@ CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::bfloat16, true>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -98,5 +99,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, paddle::float8_e4
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, false>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -100,5 +101,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::bfloat16, int8_t, true>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, f
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -100,5 +101,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float16, t
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -99,5 +100,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, paddle::float8_e4m
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -54,6 +54,7 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, false>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);
@@ -99,5 +100,6 @@ template void CascadeAppendAttentionC8Kernel<paddle::float16, int8_t, true>(
const bool causal,
const bool is_decoder,
const bool enable_prefill,
const std::string& cache_quant_type_str,
cudaStream_t& stream,
paddle::Tensor* out);

View File

@@ -441,6 +441,15 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
PD_THROW("not support the group_size", group_size); \
}
#define DISPATCH_DyCfp8(is_dynamic_cfp8, IsDynamicC8, ...) \
if (is_dynamic_cfp8) { \
constexpr bool IsDynamicC8 = true; \
__VA_ARGS__ \
} else { \
constexpr bool IsDynamicC8 = false; \
__VA_ARGS__ \
}
#define DISPATCH_MLA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \
if (group_size == 8) { \
constexpr size_t GROUP_SIZE = 8; \

View File

@@ -231,6 +231,17 @@ class AppendAttentionBackend(AttentionBackend):
metadata.kv_signal_metadata,
layer.layer_id + self.start_layer_index,
)
cache_quant_type_str = getattr(layer, "cache_quant_type_str", "none")
if cache_quant_type_str == "block_wise_fp8":
cache_k = forward_meta.caches[4 * layer.layer_id]
cache_v = forward_meta.caches[4 * layer.layer_id + 1]
cache_k_scales = forward_meta.caches[4 * layer.layer_id + 2]
cache_v_scales = forward_meta.caches[4 * layer.layer_id + 3]
else:
cache_k = forward_meta.caches[2 * layer.layer_id]
cache_v = forward_meta.caches[2 * layer.layer_id + 1]
cache_k_scales = getattr(layer, "cache_k_scale", None)
cache_v_scales = getattr(layer, "cache_v_scale", None)
if self.use_output:
quant_max_bound = getattr(layer, "quant_max_bound", 0.0)
@@ -269,8 +280,8 @@ class AppendAttentionBackend(AttentionBackend):
append_attention_with_output(
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
cache_k,
cache_v,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
@@ -293,8 +304,8 @@ class AppendAttentionBackend(AttentionBackend):
metadata.attn_mask,
layer.qkv_bias,
layer.qkv_scale,
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
cache_k_scales,
cache_v_scales,
getattr(layer, "cache_k_out_scale", None),
getattr(layer, "cache_v_out_scale", None),
getattr(layer, "cache_k_zp", None),
@@ -325,8 +336,8 @@ class AppendAttentionBackend(AttentionBackend):
else:
res = append_attention(
qkv,
forward_meta.caches[2 * layer.layer_id],
forward_meta.caches[2 * layer.layer_id + 1],
cache_k,
cache_v,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
@@ -348,8 +359,8 @@ class AppendAttentionBackend(AttentionBackend):
metadata.attn_mask,
layer.qkv_bias,
layer.qkv_scale,
getattr(layer, "cache_k_scale", None),
getattr(layer, "cache_v_scale", None),
cache_k_scales,
cache_v_scales,
getattr(layer, "cache_k_out_scale", None),
getattr(layer, "cache_v_out_scale", None),
getattr(layer, "cache_k_zp", None),

View File

@@ -33,6 +33,7 @@ class KvCacheQuantzationTypes(str, Enum):
INT8 = "int8"
FP8 = "float8_e4m3fn"
BLOCK_WISE_FP8 = "block_wise_fp8"
INT8_ZP = "int8_zp"
INT4_ZP = "int4_zp"
FP8_ZP = "float8_e4m3fn_zp"
@@ -62,7 +63,11 @@ class KvCacheQuantConfig(QuantConfigBase):
if self.quant_type == KvCacheQuantzationTypes.INT8 or self.quant_type == KvCacheQuantzationTypes.INT8_ZP:
self.max_bound = 127.0
elif self.quant_type == KvCacheQuantzationTypes.FP8 or self.quant_type == KvCacheQuantzationTypes.FP8_ZP:
elif (
self.quant_type == KvCacheQuantzationTypes.FP8
or self.quant_type == KvCacheQuantzationTypes.FP8_ZP
or self.quant_type == KvCacheQuantzationTypes.BLOCK_WISE_FP8
):
self.max_bound = 448.0
elif self.quant_type == KvCacheQuantzationTypes.INT4_ZP:
self.max_bound = 7.0
@@ -178,9 +183,14 @@ class KVCacheMethodBase(QuantMethodBase):
layer.cache_quant_type_str = "cache_int4_zp"
layer.quant_max_bound = 7.0
layer.quant_min_bound = -7.0
elif self.cache_quant_config.quant_type == KvCacheQuantzationTypes.BLOCK_WISE_FP8:
layer.cache_quant_type_str = "block_wise_fp8"
layer.quant_max_bound = 448.0
layer.quant_min_bound = -448.0
else:
raise NotImplementedError(f"{self.cache_quant_config.quant_type} is not implemented")
if "block_wise" not in layer.cache_quant_type_str:
self.load_scale(layer, state_dict)
if self.cache_quant_config.has_zero_point:
self.load_zp(layer, state_dict)

View File

@@ -1023,6 +1023,8 @@ class GPUModelRunner(ModelRunnerBase):
kv_cache_shape = self.attn_backends[0].get_kv_cache_shape(
max_num_blocks=max_block_num, kv_cache_quant_type=kv_cache_quant_type
)
if kv_cache_quant_type == "block_wise_fp8":
kv_cache_scale_shape = [kv_cache_shape[0], kv_cache_shape[1], kv_cache_shape[2]]
local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
if not profile and (self.cache_config.enable_prefix_caching or self.parallel_config.splitwise_role != "mixed"):
@@ -1050,6 +1052,17 @@ class GPUModelRunner(ModelRunnerBase):
fill_value=0,
dtype=cache_type,
)
if kv_cache_quant_type == "block_wise_fp8":
cache_kvs[f"key_cache_scales_{i}"] = paddle.full(
shape=kv_cache_scale_shape,
fill_value=0,
dtype=paddle.get_default_dtype(),
)
cache_kvs[f"value_cache_scales_{i}"] = paddle.full(
shape=kv_cache_scale_shape,
fill_value=0,
dtype=paddle.get_default_dtype(),
)
self.share_inputs["caches"] = list(cache_kvs.values())
for value in cache_kvs.values():
del value

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import time
import unittest
@@ -20,6 +21,7 @@ import paddle
from paddle.incubate.nn.functional import fused_rms_norm
paddle.seed(10)
np.random.seed(10)
class RopeEmbedding:
@@ -334,7 +336,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.name = "TestAppendGroupQueryAttnWithRope"
self.place = paddle.CUDAPlace(0)
self.batch_size = 1
self.q_num_head = 12
self.q_num_head = 16
self.kv_num_head = 2
self.seq_len = 64
self.max_dec_len = 64
@@ -347,9 +349,10 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.max_seq_len = self.seq_len + self.max_dec_len
self.softmax_scale = self.dim_head**-0.5
self.rope_theta = 10000
self.dtype = "float16"
self.dtype = "bfloat16"
self.use_qk_norm = True
self.use_mask_offset = False
self.use_dynamic_quant = False
self.init_tensor()
def init_tensor(self):
@@ -391,8 +394,23 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
)
self.scale = 1.0 / np.sqrt(self.dim_head)
if self.use_dynamic_quant:
self.cache_scale_shape = (
self.max_block_num,
self.kv_num_head,
self.blocksize,
)
self.cache_k = paddle.zeros(shape=self.cache_shape, dtype="uint8")
self.cache_v = paddle.zeros(shape=self.cache_shape, dtype="uint8")
self.cache_k_T = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
self.cache_v_T = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
self.key_cache_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype)
self.value_cache_scale = paddle.zeros(shape=self.cache_scale_shape, dtype=self.dtype)
else:
self.cache_k = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
self.cache_v = paddle.zeros(shape=self.cache_shape, dtype=self.dtype)
self.key_cache_scale = None
self.value_cache_scale = None
self.block_tables = paddle.zeros(shape=(self.batch_size, self.block_num_per_seq), dtype="int32")
for i in range(self.batch_size):
need_block_num = (self.seq_len + self.max_dec_len + self.blocksize - 1) // self.blocksize
@@ -415,6 +433,7 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
def cmp_append_attention(self, naive_cache_k=None, naive_cache_v=None, attn_mask=None):
paddle.disable_static()
print("use_dynamic_quant: ", self.use_dynamic_quant)
self.token_num = self.seq_len * self.batch_size
q, k, v, qkv = get_qkv_and_qkv_concat_tensor(
self.batch_size,
@@ -472,18 +491,17 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
self.blocksize,
speculate_max_draft_token_num + 1,
)
if self.use_dynamic_quant:
cache_quant_type = "block_wise_fp8"
else:
cache_quant_type = "none"
# Warm up
WARM_UP = 1
RUN_TIME = 2
for i in range(WARM_UP + RUN_TIME):
if i == WARM_UP:
paddle.device.synchronize()
start_time = time.time()
out = append_attention(
qkv,
self.cache_k,
self.cache_v,
if self.use_dynamic_quant:
qkv_copy = copy.deepcopy(qkv)
append_attention(
qkv_copy,
self.cache_k_T,
self.cache_v_T,
self.seq_lens_encoder,
self.seq_lens_decoder,
self.seq_lens_this_time,
@@ -519,7 +537,69 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
k_norm_weight, # k_norm_weight
1e-6,
"fp16",
"none", # cache_quant_type
"none",
self.use_neox_rotary_style,
False,
self.max_seq_len,
0.0, # quant_min_bound
0.0, # quant_max_bound
-1, # out_linear_in_scale
64, # encoder_block_shape_q
16, # decoder_block_shape_q
32768, # max_partition_size
32768, # encoder_max_partition_size
speculate_max_draft_token_num + 1, # speculate_max_draft_token_num
True, # causal
False, # speculate_decoder
)
# Warm up
WARM_UP = 1
RUN_TIME = 2
for i in range(WARM_UP + RUN_TIME):
if i == WARM_UP:
paddle.device.synchronize()
start_time = time.time()
out = append_attention(
qkv,
self.cache_k,
self.cache_v,
self.seq_lens_encoder,
self.seq_lens_decoder,
self.seq_lens_this_time,
self.padding_offset,
self.cum_offset,
self.block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks,
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
self.decoder_batch_ids,
self.decoder_tile_ids_per_batch,
self.decoder_num_blocks_cpu,
self.max_len_tensor_cpu,
max_len_kv,
self.rope_emb, # rope_emb
None, # attn_mask
None, # qkv_bias
None, # qkv_out_scales
self.key_cache_scale, # cache_k_quant_scales
self.value_cache_scale, # cache_v_quant_scales
None, # cache_k_dequant_scales
None, # cache_v_dequant_scales
None, # cache_k_zp
None, # cache_v_zp
None, # linear_shift
None, # linear_smooth
self.mask_offset, # mask_offset
None, # kv_signal_data
q_norm_weight, # q_norm_weight
k_norm_weight, # k_norm_weight
1e-6,
"fp16",
cache_quant_type,
self.use_neox_rotary_style,
False,
self.max_seq_len,
@@ -537,13 +617,6 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
paddle.device.synchronize()
end_time = time.time()
print(f"[append-attn ut] cost_time:{(end_time - start_time) / RUN_TIME * 1000}ms")
naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
self.cache_k,
self.cache_v,
self.batch_size,
self.block_tables,
self.seq_len,
)
np.testing.assert_allclose(
out.numpy(),
out_.numpy(),
@@ -572,6 +645,15 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
if self.use_mask_offset:
print("encoder mask_offset: ", self.mask_offset)
self.cmp_append_attention(attn_mask=self.attention_mask)
if self.use_dynamic_quant:
naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
self.cache_k_T,
self.cache_v_T,
self.batch_size,
self.block_tables,
self.seq_len,
)
else:
naive_cache_k, naive_cache_v = block_cache_to_naive_cache(
self.cache_k,
self.cache_v,
@@ -613,10 +695,10 @@ class TestAppendGroupQueryAttnWithRope(unittest.TestCase):
class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope):
def setUp(self):
paddle.disable_static()
self.name = "TestAppendGroupQueryAttnWithRope"
self.name = "TestAppendGroupQueryAttnWithNeoXRope"
self.place = paddle.CUDAPlace(0)
self.batch_size = 1
self.q_num_head = 12
self.q_num_head = 16
self.kv_num_head = 2
self.seq_len = 64
self.max_dec_len = 64
@@ -632,6 +714,33 @@ class TestAppendGroupQueryAttnWithNeoXRope(TestAppendGroupQueryAttnWithRope):
self.dtype = "float16"
self.use_qk_norm = False
self.use_mask_offset = True
self.use_dynamic_quant = False
self.init_tensor()
class TestAppendGroupQueryAttnWithRopeDyCfp8(TestAppendGroupQueryAttnWithRope):
def setUp(self):
paddle.disable_static()
self.name = "TestAppendGroupQueryAttnWithRopeDyCfp8"
self.place = paddle.CUDAPlace(0)
self.batch_size = 1
self.q_num_head = 16
self.kv_num_head = 2
self.seq_len = 64
self.max_dec_len = 64
self.dim_head = 128
self.q_hid_dim = self.q_num_head * self.dim_head
self.kv_hid_dim = self.kv_num_head * self.dim_head
self.blocksize = 64
self.use_neox_rotary_style = False
# max_seq_len = self.seq_len + self.max_dec_len
self.max_seq_len = self.seq_len + self.max_dec_len
self.softmax_scale = self.dim_head**-0.5
self.rope_theta = 10000
self.dtype = "bfloat16"
self.use_qk_norm = True
self.use_mask_offset = False
self.use_dynamic_quant = True
self.init_tensor()