fix conflicts

Authored by carryyu on 2025-11-13 18:17:44 +08:00; committed by lizhenyun01
parent ae7bee8122
commit 6c3d1da62f
2 changed files with 381 additions and 301 deletions


@@ -383,56 +383,45 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
}
}
template <uint32_t block_size,
template <SharedMemFillMode fill_mode,
uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_k_dynamic_scale(
T* k_smem_scale,
T* cache_k_reg,
__device__ __forceinline__ void produce_kv_dynamic_scale_gmem2smem_async(
smem_t kv_scale_smem,
const int* block_table_now,
const T* cache_k_scale,
const T* cache_kv_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
const uint32_t tid = ty * 32 + tx;
if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
if (tid < block_size) {
k_smem_scale[tid] = cache_k_scale_now[tid];
}
__syncthreads();
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_k_reg[fz * 2] = k_smem_scale[fz * 16 + row_id];
cache_k_reg[fz * 2 + 1] = k_smem_scale[fz * 16 + row_id + 8];
if (tid < block_size / 8) {
const T* cache_k_scale_now = cache_kv_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size + tid * 8;
const int kv_idx_this_thread = kv_idx + tid * 8;
kv_scale_smem.load_128b_async<fill_mode>(
tid, cache_k_scale_now, kv_idx_this_thread < chunk_end);
}
} else {
// 1 warp 32 tokens
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
} else {
k_smem_scale[ty * 32 + tx] = 0;
}
__syncwarp();
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_k_reg[fz * 2] = k_smem_scale[ty * 32 + fz * 16 + row_id];
cache_k_reg[fz * 2 + 1] = k_smem_scale[ty * 32 + fz * 16 + row_id + 8];
if (tid < block_size / 8 * 2) {
const uint32_t kv_idx_now = kv_idx + block_size * tid / 8;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const int kv_idx_this_thread = kv_idx + tid * 8;
const T* cache_k_scale_now = cache_kv_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size + tid % 8 * 8;
kv_scale_smem.load_128b_async<fill_mode>(
tid, cache_k_scale_now, kv_idx_this_thread < chunk_end);
}
}
}
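The rewritten producer stages the per-token dequantization scales into shared memory with 128-bit asynchronous copies, so a single thread moves eight 16-bit scales and the transfer overlaps with the K/V tile loads issued alongside it. Below is a minimal standalone sketch of the same pattern built on the standard CUDA pipeline primitives; smem_t::load_128b_async, commit_group and wait_group in the diff are assumed to wrap the same mechanism, and the kernel name and shape below are illustrative, not from the codebase.

#include <cuda_fp16.h>
#include <cuda_pipeline.h>  // __pipeline_memcpy_async / __pipeline_commit / __pipeline_wait_prior

// Each thread copies eight 16-bit scales (16 bytes) from global to shared
// memory, zero-filling positions at or beyond chunk_end instead of reading
// them. True cp.async needs sm_80+; older arches fall back to sync copies.
// gmem_scale is assumed 16-byte aligned.
template <uint32_t BLOCK_SIZE>
__global__ void load_scales_async_demo(const half* __restrict__ gmem_scale,
                                       uint32_t chunk_end,
                                       half* __restrict__ out) {
  __shared__ alignas(16) half smem_scale[BLOCK_SIZE];
  const uint32_t tid = threadIdx.x;
  if (tid < BLOCK_SIZE / 8) {
    const bool in_range = tid * 8 < chunk_end;
    __pipeline_memcpy_async(smem_scale + tid * 8, gmem_scale + tid * 8,
                            /*size=*/16, /*zfill=*/in_range ? 0 : 16);
  }
  __pipeline_commit();       // analogous to commit_group()
  __pipeline_wait_prior(0);  // the kernels defer this via wait_group<N>()
  __syncthreads();
  if (tid < BLOCK_SIZE) out[tid] = smem_scale[tid];
}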
@@ -441,57 +430,55 @@ template <uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_v_dynamic_scale(
T* v_smem_scale,
T* cache_v_reg,
const int* block_table_now,
const T* cache_v_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end) {
__device__ __forceinline__ void produce_k_dynamic_scale_smem2reg(
T* k_smem_scale, T* cache_k_reg) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
const uint32_t scale_idx = fz * 16 + row_id;
cache_k_reg[fz * 2] = k_smem_scale[scale_idx];
cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 8];
}
} else {
// 1 warp 32 tokens
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
const uint32_t scale_idx = ty * 32 + fz * 16 + row_id;
cache_k_reg[fz * 2] = k_smem_scale[scale_idx];
cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 8];
}
}
}
template <uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_v_dynamic_scale_smem2reg(
T* v_smem_scale, T* cache_v_reg) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
if (tid < block_size) {
v_smem_scale[tid] = cache_v_scale_now[tid];
}
__syncthreads();
const uint32_t row_id = tx % 4 * 2;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_v_reg[fz * 4] = v_smem_scale[fz * 16 + row_id];
cache_v_reg[fz * 4 + 1] = v_smem_scale[fz * 16 + row_id + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[fz * 16 + row_id + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[fz * 16 + row_id + 9];
const uint32_t scale_idx = fz * 16 + row_id;
cache_v_reg[fz * 4] = v_smem_scale[scale_idx];
cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9];
}
} else {
// 1 warp 32 tokens
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
} else {
v_smem_scale[ty * 32 + tx] = 0;
}
__syncwarp();
const uint32_t row_id = tx % 4 * 2;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_v_reg[fz * 4] = v_smem_scale[ty * 32 + fz * 16 + row_id];
cache_v_reg[fz * 4 + 1] = v_smem_scale[ty * 32 + fz * 16 + row_id + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[ty * 32 + fz * 16 + row_id + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[ty * 32 + fz * 16 + row_id + 9];
const uint32_t scale_idx = ty * 32 + fz * 16 + row_id;
cache_v_reg[fz * 4] = v_smem_scale[scale_idx];
cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9];
}
}
}
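The two smem2reg helpers only redistribute scales that are already resident in shared memory into the per-lane register layout the MMA fragments expect: for K, lane tx holds rows tx/4 and tx/4 + 8 of each 16-row fragment; for V, it holds rows tx%4*2 and tx%4*2 + 1 plus the same pair offset by 8 (with an extra ty*32 offset in the one-warp-per-32-tokens branch). A small host-side sketch that prints this mapping, useful for checking the indexing:

#include <cstdio>
#include <cstdint>

// Which shared-memory scale slots each lane reads for one 16-row fragment
// (fz == 0), matching produce_k/v_dynamic_scale_smem2reg above.
int main() {
  for (uint32_t tx = 0; tx < 32; ++tx) {
    const uint32_t k_row = tx / 4;      // K: two rows per lane
    const uint32_t v_row = tx % 4 * 2;  // V: four rows per lane
    std::printf("lane %2u  K: {%2u,%2u}  V: {%2u,%2u,%2u,%2u}\n",
                tx, k_row, k_row + 8,
                v_row, v_row + 1, v_row + 8, v_row + 9);
  }
  return 0;
}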


@@ -34,15 +34,17 @@ template <typename T,
bool IsFP8 = false,
bool IsDynamicC8 = false>
__global__ void multi_query_append_attention_c8_kernel(
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
CacheT *__restrict__ cache_v,
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num,
// num_heads, block_size]
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num,
// num_heads, block_size]
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T *__restrict__ sinks, // [q_num_heads]
const T *__restrict__ sinks, // [q_num_heads]
const int *__restrict__ seq_lens,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
@@ -89,8 +91,8 @@ __global__ void multi_query_append_attention_c8_kernel(
block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
// When capturing prefill with CUDA graphs, more gridDim.x blocks than needed may be launched
if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
return;
}
@@ -190,7 +192,8 @@ __global__ void multi_query_append_attention_c8_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
const int *mask_offset_this_seq =
mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -211,15 +214,19 @@ __global__ void multi_query_append_attention_c8_kernel(
smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)),
v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
T* k_smem_scale = nullptr;
T* v_smem_scale = nullptr;
T *k_smem_scale_ptr = nullptr;
T *v_smem_scale_ptr = nullptr;
smem_t k_scale_smem;
smem_t v_scale_smem;
if constexpr (IsDynamicC8) {
k_smem_scale = reinterpret_cast<T*>(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
v_smem_scale = k_smem_scale + num_frags_z * 16;
k_smem_scale_ptr = reinterpret_cast<T *>(
smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
v_smem_scale_ptr = k_smem_scale_ptr + num_frags_z * 16;
k_scale_smem.base = reinterpret_cast<b128_t *>(k_smem_scale_ptr);
v_scale_smem.base = reinterpret_cast<b128_t *>(v_smem_scale_ptr);
}
const uint32_t num_iterations = div_up(
CAUSAL
? (min(chunk_len,
@@ -230,12 +237,13 @@ __global__ void multi_query_append_attention_c8_kernel(
: chunk_len,
num_frags_z * 16);
const uint32_t mask_check_iteration =
(CAUSAL ? (min(chunk_len,
(CAUSAL ? (min(chunk_len,
sub_if_greater_or_zero(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: mask_offset ? 0
: chunk_len) /
(num_frags_z * 16);
uint32_t k_smem_offset_r =
@@ -248,8 +256,7 @@ __global__ void multi_query_append_attention_c8_kernel(
uint32_t k_smem_offset_w =
smem_t::get_permuted_offset<num_vecs_per_head_k, inv_k_stride>(
wid * 4 + tid / 8,
tid % 8);
wid * 4 + tid / 8, tid % 8);
uint32_t v_smem_offset_w =
smem_t::get_permuted_offset<num_vecs_per_blocksize, inv_v_stride>(
wid * 8 + tid / 4, tid % 4); // 4 * 128 / 8 = 64
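For reference, the shared-memory arena carved up above (q tile, K tile, V tile, then the two dynamic-scale strips when IsDynamicC8 is set) follows the pointer arithmetic visible in this kernel; a compile-time sketch of that layout, with illustrative names:

#include <cstddef>
#include <cstdint>

// Byte layout of the dynamic shared-memory arena used by the multi-warp-Q
// kernel above (IsDynamicC8 case); mirrors the pointer math in the diff.
template <typename T, typename CacheT,
          uint32_t NUM_WARPS, uint32_t num_frags_x, uint32_t num_frags_z,
          uint32_t HEAD_DIM>
struct SmemLayoutC8 {
  static constexpr size_t q_bytes  = NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T);
  static constexpr size_t kv_bytes = num_frags_z * 16 * HEAD_DIM * sizeof(CacheT);  // one K or V tile
  static constexpr size_t scale_bytes = num_frags_z * 16 * sizeof(T);  // per-token scales, one tile

  static constexpr size_t q_off       = 0;
  static constexpr size_t k_off       = q_bytes;
  static constexpr size_t v_off       = q_bytes + kv_bytes;
  static constexpr size_t k_scale_off = q_bytes + 2 * kv_bytes;
  static constexpr size_t v_scale_off = k_scale_off + scale_bytes;
  static constexpr size_t total       = v_scale_off + scale_bytes;
};
// e.g. SmemLayoutC8<half, uint8_t, 4, 1, 4, 128>::total is about 32 KB.
// (The warp1_4 variant multiplies the K/V and scale terms by NUM_WARP_KV
// and drops NUM_WARPS from the q term.)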
@@ -278,6 +285,18 @@ __global__ void multi_query_append_attention_c8_kernel(
kv_idx_base,
chunk_end,
const_k_offset);
if constexpr (IsDynamicC8) {
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
BLOCK_SIZE,
num_frags_z,
NUM_WARP_Q>(k_scale_smem,
block_table_now,
cache_k_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end);
}
commit_group();
produce_v_blockwise_c8<SharedMemFillMode::kNoFill,
NUM_WARPS,
@@ -295,32 +314,42 @@ __global__ void multi_query_append_attention_c8_kernel(
kv_idx_base,
chunk_end,
const_v_offset);
if constexpr (IsDynamicC8) {
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
BLOCK_SIZE,
num_frags_z,
NUM_WARP_Q>(v_scale_smem,
block_table_now,
cache_v_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end);
}
commit_group();
#pragma unroll 1
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
if constexpr (IsDynamicC8) {
produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
k_smem_scale,
cache_k_scale_reg,
block_table_now,
cache_k_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
wait_group<1>();
__syncthreads();
if constexpr (IsDynamicC8) {
produce_k_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
k_smem_scale_ptr, cache_k_scale_reg);
}
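    // Note on the new staging: the K tile and its dynamic scales were
    // committed as one async-copy group, and the V tile plus V scales as the
    // next, so the wait_group<1> above only guarantees that the K data
    // (including k_smem_scale_ptr) has arrived; the V group may still be in
    // flight. The K scales are therefore promoted to registers here, right
    // before compute_qk_c8 consumes them.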
// s = qk
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
&qo_smem,
&q_smem_offset_r,
&k_smem,
&k_smem_offset_r,
cache_k_scale_reg,
s_frag);
compute_qk_c8<num_frags_x,
num_frags_y,
num_frags_z,
T,
CacheT,
is_scale_channel_wise,
IsFP8,
IsDynamicC8>(&qo_smem,
&q_smem_offset_r,
&k_smem,
&k_smem_offset_r,
cache_k_scale_reg,
s_frag);
// mask according to kv_idx and q_idx
if (iter >= mask_check_iteration || sliding_window > 0) {
@@ -366,21 +395,25 @@ __global__ void multi_query_append_attention_c8_kernel(
kv_idx_base,
chunk_end,
const_k_offset);
commit_group();
if constexpr (IsDynamicC8) {
produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
v_smem_scale,
cache_v_scale_reg,
block_table_now,
cache_v_scale,
ori_kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
BLOCK_SIZE,
num_frags_z,
NUM_WARP_Q>(k_scale_smem,
block_table_now,
cache_k_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end);
}
commit_group();
wait_group<1>();
__syncthreads();
if constexpr (IsDynamicC8) {
produce_v_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
v_smem_scale_ptr, cache_v_scale_reg);
}
// compute sfm*v
compute_sfm_v_c8<num_frags_x,
@@ -411,8 +444,19 @@ __global__ void multi_query_append_attention_c8_kernel(
kv_idx_base,
chunk_end,
const_v_offset);
if constexpr (IsDynamicC8) {
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
BLOCK_SIZE,
num_frags_z,
NUM_WARP_Q>(v_scale_smem,
block_table_now,
cache_v_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end);
}
commit_group();
}
wait_group<0>();
__syncthreads();
@@ -420,15 +464,19 @@ __global__ void multi_query_append_attention_c8_kernel(
if constexpr (!partition_kv) {
if (sinks) {
float current_sinks[num_frags_x][2];
#pragma unroll
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE;
current_sinks[fx][j] = static_cast<float>(sinks[q_head_idx + h_offset]);
const uint32_t h_offset =
(q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) %
GROUP_SIZE;
current_sinks[fx][j] =
static_cast<float>(sinks[q_head_idx + h_offset]);
}
}
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag, m_frag, current_sinks);
normalize_d<num_frags_x, num_frags_y>(
o_frag, d_frag, m_frag, current_sinks);
} else {
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
}
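The sink branch only changes the final normalization: each head's learned sink logit adds exp(sink - m) to the softmax denominator without contributing a value vector. A scalar sketch of that step, under the assumption that this is what normalize_d does with current_sinks (the real code operates on the o_frag/d_frag/m_frag fragments):

// Scalar illustration of softmax normalization with an attention sink.
// o_acc = sum_j exp(s_j - m) * v_j,  d = sum_j exp(s_j - m),  m = row max.
__device__ __forceinline__ float normalize_with_sink(float o_acc, float d,
                                                     float m, float sink) {
  return o_acc / (d + __expf(sink - m));
}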
@@ -474,7 +522,6 @@ __global__ void multi_query_append_attention_c8_kernel(
HEAD_DIM);
}
if constexpr (partition_kv) {
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
@@ -520,16 +567,18 @@ template <typename T,
uint32_t num_frags_y,
typename OutT = T,
bool ENABLE_PREFILL = true,
bool is_scale_channel_wise=false,
bool is_scale_channel_wise = false,
bool IsFP8 = false,
bool IsDynamicC8 = false>
__global__ void multi_query_append_attention_c8_warp1_4_kernel(
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
CacheT *__restrict__ cache_v,
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num, num_heads, block_size]
const T *__restrict__ cache_k_scale, // [num_kv_heads] or [max_block_num,
// num_heads, block_size]
const T *__restrict__ cache_v_scale, // [num_kv_heads] or [max_block_num,
// num_heads, block_size]
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T *__restrict__ sinks, // [q_num_heads]
@@ -540,7 +589,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -580,8 +629,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
// When capturing prefill with CUDA graphs, more gridDim.x blocks than needed may be launched
if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
return;
}
@@ -678,7 +727,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
const int *mask_offset_this_seq =
mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -703,12 +753,17 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)),
v_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT));
T* k_smem_scale = nullptr;
T* v_smem_scale = nullptr;
T *k_smem_scale_ptr = nullptr;
T *v_smem_scale_ptr = nullptr;
smem_t k_scale_smem;
smem_t v_scale_smem;
if constexpr (IsDynamicC8) {
k_smem_scale = reinterpret_cast<T*>(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
v_smem_scale = k_smem_scale + NUM_WARP_KV * num_frags_z * 16;
k_smem_scale_ptr = reinterpret_cast<T *>(
smem + num_frags_x * 16 * HEAD_DIM * sizeof(T) +
NUM_WARP_KV * num_frags_z * 16 * HEAD_DIM * sizeof(CacheT) * 2);
v_smem_scale_ptr = k_smem_scale_ptr + NUM_WARP_KV * num_frags_z * 16;
k_scale_smem.base = reinterpret_cast<b128_t *>(k_smem_scale_ptr);
v_scale_smem.base = reinterpret_cast<b128_t *>(v_smem_scale_ptr);
}
const uint32_t num_iterations = div_up(
@@ -721,12 +776,13 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
: chunk_len,
NUM_WARP_KV * num_frags_z * 16);
const uint32_t mask_check_iteration =
(CAUSAL ? (min(chunk_len,
(CAUSAL ? (min(chunk_len,
sub_if_greater_or_zero(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: mask_offset ? 0
: chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);
uint32_t k_smem_offset_r =
@@ -740,9 +796,7 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
uint32_t k_smem_offset_w =
smem_t::get_permuted_offset<num_vecs_per_head_k, inv_k_stride>(
wid * 4 + tid / 8,
tid %
8);
wid * 4 + tid / 8, tid % 8);
uint32_t v_smem_offset_w =
smem_t::get_permuted_offset<num_vecs_per_blocksize, inv_v_stride>(
wid * 8 + tid / 4, tid % 4);
@@ -772,6 +826,18 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
kv_idx_base,
chunk_end,
const_k_offset);
if constexpr (IsDynamicC8) {
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
BLOCK_SIZE,
num_frags_z,
NUM_WARP_Q>(k_scale_smem,
block_table_now,
cache_k_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end);
}
commit_group();
produce_v_blockwise_c8<SharedMemFillMode::kNoFill,
NUM_WARPS,
@@ -789,32 +855,42 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
kv_idx_base,
chunk_end,
const_v_offset);
if constexpr (IsDynamicC8) {
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
BLOCK_SIZE,
num_frags_z,
NUM_WARP_Q>(v_scale_smem,
block_table_now,
cache_v_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end);
}
commit_group();
#pragma unroll 1
for (uint32_t iter = 0; iter < num_iterations; ++iter) {
if constexpr (IsDynamicC8) {
produce_k_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
k_smem_scale,
cache_k_scale_reg,
block_table_now,
cache_k_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
}
wait_group<1>();
__syncthreads();
if constexpr (IsDynamicC8) {
produce_k_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
k_smem_scale_ptr, cache_k_scale_reg);
}
// s = qk
compute_qk_c8<num_frags_x, num_frags_y, num_frags_z, T, CacheT, is_scale_channel_wise, IsFP8, IsDynamicC8>(
&qo_smem,
&q_smem_offset_r,
&k_smem,
&k_smem_offset_r,
cache_k_scale_reg,
s_frag);
compute_qk_c8<num_frags_x,
num_frags_y,
num_frags_z,
T,
CacheT,
is_scale_channel_wise,
IsFP8,
IsDynamicC8>(&qo_smem,
&q_smem_offset_r,
&k_smem,
&k_smem_offset_r,
cache_k_scale_reg,
s_frag);
// mask according to kv_idx and q_idx
if (iter >= mask_check_iteration || sliding_window > 0) {
mask_s<T,
@@ -824,17 +900,18 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq,
sliding_window);
num_frags_z>(
attn_mask ? attn_mask + batch_id * attn_mask_len * attn_mask_len
: nullptr,
q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq,
sliding_window);
}
// update m,d
@@ -860,21 +937,25 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
kv_idx_base,
chunk_end,
const_k_offset);
commit_group();
if constexpr (IsDynamicC8) {
produce_v_dynamic_scale<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
v_smem_scale,
cache_v_scale_reg,
block_table_now,
cache_v_scale,
ori_kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end
);
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
BLOCK_SIZE,
num_frags_z,
NUM_WARP_Q>(k_scale_smem,
block_table_now,
cache_k_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end);
}
commit_group();
wait_group<1>();
__syncthreads();
if constexpr (IsDynamicC8) {
produce_v_dynamic_scale_smem2reg<BLOCK_SIZE, num_frags_z, NUM_WARP_Q, T>(
v_smem_scale_ptr, cache_v_scale_reg);
}
// compute sfm * v
compute_sfm_v_c8_iter_sq_bvec<num_frags_x,
@@ -905,6 +986,18 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
kv_idx_base,
chunk_end,
const_v_offset);
if constexpr (IsDynamicC8) {
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
BLOCK_SIZE,
num_frags_z,
NUM_WARP_Q>(v_scale_smem,
block_table_now,
cache_v_scale,
kv_idx_base,
kv_num_heads,
kv_head_idx,
chunk_end);
}
commit_group();
}
wait_group<0>();
@@ -916,15 +1009,19 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
if (num_chunks_this_seq <= 1) {
if (sinks) {
float current_sinks[num_frags_x][2];
#pragma unroll
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE;
current_sinks[fx][j] = static_cast<float>(sinks[q_head_idx + h_offset]);
const uint32_t h_offset =
(q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) %
GROUP_SIZE;
current_sinks[fx][j] =
static_cast<float>(sinks[q_head_idx + h_offset]);
}
}
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag, m_frag, current_sinks);
normalize_d<num_frags_x, num_frags_y>(
o_frag, d_frag, m_frag, current_sinks);
} else {
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
}
@@ -981,7 +1078,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE;
const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
if (qo_idx - q_start_seq_id < q_len) {
uint32_t offset;
if (ENABLE_PREFILL) {
offset = (batch_id * num_chunks + chunk_idx) * q_num_heads +
@@ -1095,25 +1191,24 @@ void MultiQueryAppendC8Attention(
IsFP8,
IsDynamicC8>;
if (is_scale_channel_wise) {
split_kv_kernel =
multi_query_append_attention_c8_kernel<NV_TYPE,
uint8_t,
true,
GROUP_SIZE,
CAUSAL,
num_warps,
NUM_WARP_Q,
NUM_WARP_KV,
HEAD_DIM,
BLOCK_SIZE,
num_frags_x,
num_frags_z,
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true,
IsFP8,
IsDynamicC8>;
split_kv_kernel = multi_query_append_attention_c8_kernel<NV_TYPE,
uint8_t,
true,
GROUP_SIZE,
CAUSAL,
num_warps,
NUM_WARP_Q,
NUM_WARP_KV,
HEAD_DIM,
BLOCK_SIZE,
num_frags_x,
num_frags_z,
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true,
IsFP8,
IsDynamicC8>;
}
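      // The selected kernel must opt in before it can use more than the
      // default 48 KB of dynamic shared memory; the guarded call below uses
      // the standard pattern, e.g.
      //   cudaFuncSetAttribute(split_kv_kernel,
      //                        cudaFuncAttributeMaxDynamicSharedMemorySize,
      //                        smem_size);
      // (arguments past the kernel pointer are not visible in this hunk; the
      // attribute named here is the usual one for raising that limit).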
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(split_kv_kernel,
@@ -1152,24 +1247,24 @@ void MultiQueryAppendC8Attention(
IsDynamicC8>;
if (is_scale_channel_wise) {
nosplit_kv_kernel =
multi_query_append_attention_c8_kernel<NV_TYPE,
uint8_t,
false,
GROUP_SIZE,
CAUSAL,
num_warps,
NUM_WARP_Q,
NUM_WARP_KV,
HEAD_DIM,
BLOCK_SIZE,
num_frags_x,
num_frags_z,
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true,
IsFP8,
IsDynamicC8>;
multi_query_append_attention_c8_kernel<NV_TYPE,
uint8_t,
false,
GROUP_SIZE,
CAUSAL,
num_warps,
NUM_WARP_Q,
NUM_WARP_KV,
HEAD_DIM,
BLOCK_SIZE,
num_frags_x,
num_frags_z,
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true,
IsFP8,
IsDynamicC8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(nosplit_kv_kernel,
@@ -1190,8 +1285,8 @@ void MultiQueryAppendC8Attention(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1253,8 +1348,8 @@ void MultiQueryAppendC8Attention(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1305,8 +1400,8 @@ void MultiQueryAppendC8Attention(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
@@ -1319,8 +1414,7 @@ void MultiQueryAppendC8Attention(
} else {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads);
dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
dim3 blocks_merge(blockx, blocky);
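      // Illustrative numbers (assuming HEAD_DIM == 128 and vec_size == 8):
      //   blockx = 128 / 8 = 16 lanes across the head dimension
      //   blocky = (128 + 16 - 1) / 16 = 8 rows  ->  16 * 8 = 128 threads/block
      // min(sm_count * 4, token_num) caps the merge grid at four blocks per SM,
      // or at token_num when there are fewer rows to merge.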
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
@@ -1344,8 +1438,8 @@ void MultiQueryAppendC8Attention(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
@@ -1386,24 +1480,24 @@ void MultiQueryAppendC8Attention(
IsDynamicC8>;
if (is_scale_channel_wise) {
split_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
uint8_t,
true,
GROUP_SIZE,
CAUSAL,
num_warps,
NUM_WARP_Q,
NUM_WARP_KV,
HEAD_DIM,
BLOCK_SIZE,
num_frags_x,
num_frags_z,
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true,
IsFP8,
IsDynamicC8>;
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
uint8_t,
true,
GROUP_SIZE,
CAUSAL,
num_warps,
NUM_WARP_Q,
NUM_WARP_KV,
HEAD_DIM,
BLOCK_SIZE,
num_frags_x,
num_frags_z,
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true,
IsFP8,
IsDynamicC8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(split_kv_kernel,
@@ -1421,9 +1515,9 @@ void MultiQueryAppendC8Attention(
const int num_chunks = div_up(max_seq_len, chunk_size);
uint32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
attn_mask_len = attn_mask.get().shape()[1];
} else {
attn_mask_len = -1;
attn_mask_len = -1;
}
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
@@ -1450,24 +1544,24 @@ void MultiQueryAppendC8Attention(
IsDynamicC8>;
if (is_scale_channel_wise) {
nosplit_kv_kernel =
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
uint8_t,
false,
GROUP_SIZE,
CAUSAL,
num_warps,
NUM_WARP_Q,
NUM_WARP_KV,
HEAD_DIM,
BLOCK_SIZE,
num_frags_x,
num_frags_z,
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true,
IsFP8,
IsDynamicC8>;
multi_query_append_attention_c8_warp1_4_kernel<NV_TYPE,
uint8_t,
false,
GROUP_SIZE,
CAUSAL,
num_warps,
NUM_WARP_Q,
NUM_WARP_KV,
HEAD_DIM,
BLOCK_SIZE,
num_frags_x,
num_frags_z,
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL,
true,
IsFP8,
IsDynamicC8>;
}
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(nosplit_kv_kernel,
@@ -1488,8 +1582,8 @@ void MultiQueryAppendC8Attention(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1498,7 +1592,7 @@ void MultiQueryAppendC8Attention(
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1561,14 +1655,14 @@ void MultiQueryAppendC8Attention(
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_scale.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_scale.data<T>())),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1577,7 +1671,7 @@ void MultiQueryAppendC8Attention(
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1611,14 +1705,14 @@ void MultiQueryAppendC8Attention(
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
@@ -1631,15 +1725,14 @@ void MultiQueryAppendC8Attention(
} else {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads);
dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
@@ -1650,14 +1743,14 @@ void MultiQueryAppendC8Attention(
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,