fix conflicts

This commit is contained in:
carryyu
2025-11-13 18:17:44 +08:00
committed by lizhenyun01
parent ae7bee8122
commit 6c3d1da62f
2 changed files with 381 additions and 301 deletions

View File

@@ -383,56 +383,45 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
}
}
template <uint32_t block_size,
template <SharedMemFillMode fill_mode,
uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_k_dynamic_scale(
T* k_smem_scale,
T* cache_k_reg,
__device__ __forceinline__ void produce_kv_dynamic_scale_gmem2smem_async(
smem_t kv_scale_smem,
const int* block_table_now,
const T* cache_k_scale,
const T* cache_kv_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
const uint32_t tid = ty * 32 + tx;
if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
if (tid < block_size) {
k_smem_scale[tid] = cache_k_scale_now[tid];
}
__syncthreads();
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_k_reg[fz * 2] = k_smem_scale[fz * 16 + row_id];
cache_k_reg[fz * 2 + 1] = k_smem_scale[fz * 16 + row_id + 8];
if (tid < block_size / 8) {
const T* cache_k_scale_now = cache_kv_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size + tid * 8;
const int kv_idx_this_thread = kv_idx + tid * 8;
kv_scale_smem.load_128b_async<fill_mode>(
tid, cache_k_scale_now, kv_idx_this_thread < chunk_end);
}
} else {
// 1 warp 32 tokens
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
} else {
k_smem_scale[ty * 32 + tx] = 0;
}
__syncwarp();
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_k_reg[fz * 2] = k_smem_scale[ty * 32 + fz * 16 + row_id];
cache_k_reg[fz * 2 + 1] = k_smem_scale[ty * 32 + fz * 16 + row_id + 8];
if (tid < block_size / 8 * 2) {
const uint32_t kv_idx_now = kv_idx + block_size * tid / 8;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const int kv_idx_this_thread = kv_idx + tid * 8;
const T* cache_k_scale_now = cache_kv_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size + tid % 8 * 8;
kv_scale_smem.load_128b_async<fill_mode>(
tid, cache_k_scale_now, kv_idx_this_thread < chunk_end);
}
}
}
@@ -441,57 +430,55 @@ template <uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_v_dynamic_scale(
T* v_smem_scale,
T* cache_v_reg,
const int* block_table_now,
const T* cache_v_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end) {
__device__ __forceinline__ void produce_k_dynamic_scale_smem2reg(
T* k_smem_scale, T* cache_k_reg) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
const uint32_t scale_idx = fz * 16 + row_id;
cache_k_reg[fz * 2] = k_smem_scale[scale_idx];
cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 8];
}
} else {
// 1 warp 32 tokens
const uint32_t row_id = tx / 4;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
const uint32_t scale_idx = ty * 32 + fz * 16 + row_id;
cache_k_reg[fz * 2] = k_smem_scale[scale_idx];
cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 8];
}
}
}
template <uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_v_dynamic_scale_smem2reg(
T* v_smem_scale, T* cache_v_reg) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
if (tid < block_size) {
v_smem_scale[tid] = cache_v_scale_now[tid];
}
__syncthreads();
const uint32_t row_id = tx % 4 * 2;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_v_reg[fz * 4] = v_smem_scale[fz * 16 + row_id];
cache_v_reg[fz * 4 + 1] = v_smem_scale[fz * 16 + row_id + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[fz * 16 + row_id + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[fz * 16 + row_id + 9];
const uint32_t scale_idx = fz * 16 + row_id;
cache_v_reg[fz * 4] = v_smem_scale[scale_idx];
cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9];
}
} else {
// 1 warp 32 tokens
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
} else {
v_smem_scale[ty * 32 + tx] = 0;
}
__syncwarp();
const uint32_t row_id = tx % 4 * 2;
for (uint32_t fz = 0; fz < num_frags_z; fz++) {
cache_v_reg[fz * 4] = v_smem_scale[ty * 32 + fz * 16 + row_id];
cache_v_reg[fz * 4 + 1] = v_smem_scale[ty * 32 + fz * 16 + row_id + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[ty * 32 + fz * 16 + row_id + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[ty * 32 + fz * 16 + row_id + 9];
const uint32_t scale_idx = ty * 32 + fz * 16 + row_id;
cache_v_reg[fz * 4] = v_smem_scale[scale_idx];
cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1];
cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8];
cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9];
}
}
}