mirror of https://github.com/PaddlePaddle/FastDeploy.git

fix conflicts
@@ -383,56 +383,45 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
   }
 }
 
-template <uint32_t block_size,
+template <SharedMemFillMode fill_mode,
+          uint32_t block_size,
           uint32_t num_frags_z,
           uint32_t NUM_WARP_Q,
           typename T>
-__device__ __forceinline__ void produce_k_dynamic_scale(
-    T* k_smem_scale,
-    T* cache_k_reg,
+__device__ __forceinline__ void produce_kv_dynamic_scale_gmem2smem_async(
+    smem_t kv_scale_smem,
     const int* block_table_now,
-    const T* cache_k_scale,
+    const T* cache_kv_scale,
     const uint32_t kv_idx,
     const uint32_t kv_num_heads,
     const uint32_t kv_head_idx,
     const uint32_t chunk_end) {
   const uint32_t tx = threadIdx.x, ty = threadIdx.y;
+  const uint32_t tid = ty * 32 + tx;
   if constexpr (NUM_WARP_Q == 4) {
     // 4 warps shared block_size
-    const uint32_t tid = ty * 32 + tx;
     int block_id = __ldg(&block_table_now[kv_idx / block_size]);
     if (block_id < 0) block_id = 0;
-    const T* cache_k_scale_now = cache_k_scale +
-                                 block_id * kv_num_heads * block_size +
-                                 kv_head_idx * block_size;
-    if (tid < block_size) {
-      k_smem_scale[tid] = cache_k_scale_now[tid];
-    }
-    __syncthreads();
-    const uint32_t row_id = tx / 4;
-    for (uint32_t fz = 0; fz < num_frags_z; fz++) {
-      cache_k_reg[fz * 2] = k_smem_scale[fz * 16 + row_id];
-      cache_k_reg[fz * 2 + 1] = k_smem_scale[fz * 16 + row_id + 8];
+    if (tid < block_size / 8) {
+      const T* cache_k_scale_now = cache_kv_scale +
+                                   block_id * kv_num_heads * block_size +
+                                   kv_head_idx * block_size + tid * 8;
+      const int kv_idx_this_thread = kv_idx + tid * 8;
+      kv_scale_smem.load_128b_async<fill_mode>(
+          tid, cache_k_scale_now, kv_idx_this_thread < chunk_end);
     }
   } else {
     // 1 warp 32 tokens
-    const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
-    int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
-    if (block_id < 0) block_id = 0;
-    const T* cache_k_scale_now = cache_k_scale +
-                                 block_id * kv_num_heads * block_size +
-                                 kv_head_idx * block_size;
-    const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
-    if (kv_idx_this_thread < chunk_end) {
-      k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
-    } else {
-      k_smem_scale[ty * 32 + tx] = 0;
-    }
-    __syncwarp();
-    const uint32_t row_id = tx / 4;
-    for (uint32_t fz = 0; fz < num_frags_z; fz++) {
-      cache_k_reg[fz * 2] = k_smem_scale[ty * 32 + fz * 16 + row_id];
-      cache_k_reg[fz * 2 + 1] = k_smem_scale[ty * 32 + fz * 16 + row_id + 8];
+    if (tid < block_size / 8 * 2) {
+      const uint32_t kv_idx_now = kv_idx + block_size * tid / 8;
+      int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
+      if (block_id < 0) block_id = 0;
+      const int kv_idx_this_thread = kv_idx + tid * 8;
+      const T* cache_k_scale_now = cache_kv_scale +
+                                   block_id * kv_num_heads * block_size +
+                                   kv_head_idx * block_size + tid % 8 * 8;
+      kv_scale_smem.load_128b_async<fill_mode>(
+          tid, cache_k_scale_now, kv_idx_this_thread < chunk_end);
     }
   }
 }
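This hunk replaces the synchronous gmem-to-smem-to-register K-scale producer with a copy-only stage: each participating thread issues one predicated 128-bit async copy covering eight scales, and out-of-range tokens are handled by the fill_mode template argument instead of an explicit zero store. With eight scales per 128-bit request, tid < block_size / 8 (or block_size / 8 * 2 for the two-block layout) is exactly the set of threads needed to cover a block's scales. As a rough sketch of what a predicated 128-bit copy such as load_128b_async presumably lowers to on sm_80+ (the wrapper internals are an assumption modeled on flashinfer-style cp.async helpers, not shown in this diff):

// Sketch only: a predicated 16-byte cp.async with zero-fill semantics
// (matching a SharedMemFillMode-style kFillZero when pred is false).
// Requires sm_80+; smem_dst must point into __shared__ memory.
__device__ __forceinline__ void predicated_load_128b(void* smem_dst,
                                                     const void* gmem_src,
                                                     bool pred) {
  const uint32_t dst =
      static_cast<uint32_t>(__cvta_generic_to_shared(smem_dst));
  asm volatile(
      "{\n"
      "  .reg .pred p;\n"
      "  setp.ne.b32 p, %0, 0;\n"
      // pred true: copy all 16 bytes from global to shared.
      "  @p  cp.async.cg.shared.global [%1], [%2], 16;\n"
      // pred false: src-size 0 zero-fills the 16 destination bytes.
      "  @!p cp.async.cg.shared.global [%1], [%2], 16, 0;\n"
      "}\n" ::"r"(static_cast<int>(pred)),
      "r"(dst), "l"(gmem_src));
}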
@@ -441,57 +430,55 @@ template <uint32_t block_size,
           uint32_t num_frags_z,
           uint32_t NUM_WARP_Q,
           typename T>
-__device__ __forceinline__ void produce_v_dynamic_scale(
-    T* v_smem_scale,
-    T* cache_v_reg,
-    const int* block_table_now,
-    const T* cache_v_scale,
-    const uint32_t kv_idx,
-    const uint32_t kv_num_heads,
-    const uint32_t kv_head_idx,
-    const uint32_t chunk_end) {
+__device__ __forceinline__ void produce_k_dynamic_scale_smem2reg(
+    T* k_smem_scale, T* cache_k_reg) {
   const uint32_t tx = threadIdx.x, ty = threadIdx.y;
   if constexpr (NUM_WARP_Q == 4) {
     // 4 warps shared block_size
-    const uint32_t tid = ty * 32 + tx;
-    int block_id = __ldg(&block_table_now[kv_idx / block_size]);
-    if (block_id < 0) block_id = 0;
-    const T* cache_v_scale_now = cache_v_scale +
-                                 block_id * kv_num_heads * block_size +
-                                 kv_head_idx * block_size;
-    if (tid < block_size) {
-      v_smem_scale[tid] = cache_v_scale_now[tid];
-    }
-    __syncthreads();
+    const uint32_t row_id = tx / 4;
+    for (uint32_t fz = 0; fz < num_frags_z; fz++) {
+      const uint32_t scale_idx = fz * 16 + row_id;
+      cache_k_reg[fz * 2] = k_smem_scale[scale_idx];
+      cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 8];
+    }
+  } else {
+    // 1 warp 32 tokens
+    const uint32_t row_id = tx / 4;
+    for (uint32_t fz = 0; fz < num_frags_z; fz++) {
+      const uint32_t scale_idx = ty * 32 + fz * 16 + row_id;
+      cache_k_reg[fz * 2] = k_smem_scale[scale_idx];
+      cache_k_reg[fz * 2 + 1] = k_smem_scale[scale_idx + 8];
+    }
+  }
+}
+
+template <uint32_t block_size,
+          uint32_t num_frags_z,
+          uint32_t NUM_WARP_Q,
+          typename T>
+__device__ __forceinline__ void produce_v_dynamic_scale_smem2reg(
+    T* v_smem_scale, T* cache_v_reg) {
+  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
+
+  if constexpr (NUM_WARP_Q == 4) {
+    // 4 warps shared block_size
     const uint32_t row_id = tx % 4 * 2;
     for (uint32_t fz = 0; fz < num_frags_z; fz++) {
-      cache_v_reg[fz * 4] = v_smem_scale[fz * 16 + row_id];
-      cache_v_reg[fz * 4 + 1] = v_smem_scale[fz * 16 + row_id + 1];
-      cache_v_reg[fz * 4 + 2] = v_smem_scale[fz * 16 + row_id + 8];
-      cache_v_reg[fz * 4 + 3] = v_smem_scale[fz * 16 + row_id + 9];
+      const uint32_t scale_idx = fz * 16 + row_id;
+      cache_v_reg[fz * 4] = v_smem_scale[scale_idx];
+      cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1];
+      cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8];
+      cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9];
     }
   } else {
     // 1 warp 32 tokens
-    const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
-    int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
-    if (block_id < 0) block_id = 0;
-    const T* cache_v_scale_now = cache_v_scale +
-                                 block_id * kv_num_heads * block_size +
-                                 kv_head_idx * block_size;
-    const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
-    if (kv_idx_this_thread < chunk_end) {
-      v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
-    } else {
-      v_smem_scale[ty * 32 + tx] = 0;
-    }
-    __syncwarp();
     const uint32_t row_id = tx % 4 * 2;
     for (uint32_t fz = 0; fz < num_frags_z; fz++) {
-      cache_v_reg[fz * 4] = v_smem_scale[ty * 32 + fz * 16 + row_id];
-      cache_v_reg[fz * 4 + 1] = v_smem_scale[ty * 32 + fz * 16 + row_id + 1];
-      cache_v_reg[fz * 4 + 2] = v_smem_scale[ty * 32 + fz * 16 + row_id + 8];
-      cache_v_reg[fz * 4 + 3] = v_smem_scale[ty * 32 + fz * 16 + row_id + 9];
+      const uint32_t scale_idx = ty * 32 + fz * 16 + row_id;
+      cache_v_reg[fz * 4] = v_smem_scale[scale_idx];
+      cache_v_reg[fz * 4 + 1] = v_smem_scale[scale_idx + 1];
+      cache_v_reg[fz * 4 + 2] = v_smem_scale[scale_idx + 8];
+      cache_v_reg[fz * 4 + 3] = v_smem_scale[scale_idx + 9];
    }
  }
}
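This hunk drops the gmem-to-smem logic from the V producer and leaves two pure smem-to-reg distributors, with the repeated index expressions hoisted into scale_idx. The index math mirrors the per-lane operand layout of the m16n8k16 tensor-core MMA that consumes the dequantized tiles: K scales follow the two rows a lane owns in a 16-row fragment (tx / 4 and tx / 4 + 8), while V scales follow the four k-dimension entries of a B-style fragment (tx % 4 * 2 plus 1, 8, 9). A small sketch of that mapping, inferred from the indices above (the fragment interpretation is an assumption, not stated in the diff):

// Sketch: per-lane scale indices for one 16-token fragment (fixed fz),
// assuming the standard sm_80 m16n8k16 fragment layout.
__device__ __forceinline__ void fragment_scale_rows(uint32_t lane,
                                                    uint32_t k_rows[2],
                                                    uint32_t v_rows[4]) {
  // K: one scale per owned row -> cache_k_reg[fz * 2 + {0, 1}].
  k_rows[0] = lane / 4;
  k_rows[1] = lane / 4 + 8;
  // V: four k-dim entries per lane -> cache_v_reg[fz * 4 + {0, 1, 2, 3}].
  const uint32_t base = lane % 4 * 2;
  v_rows[0] = base;
  v_rows[1] = base + 1;
  v_rows[2] = base + 8;
  v_rows[3] = base + 9;
}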
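Taken together, the commit splits each one-shot scale producer into an async gmem-to-smem stage and a separate smem-to-reg stage, so the scale copies can be queued alongside the K/V tile copies and fenced once per iteration. A hypothetical call sequence, assuming flashinfer-style commit_group() / wait_group<N>() wrappers over cp.async.commit_group / cp.async.wait_group (the wrapper names and the BLOCK_SIZE, NUM_FRAGS_Z, NUM_WARP_Q constants are placeholders, not taken from this commit):

// Hypothetical pipeline sketch around the new producers.
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
                                         BLOCK_SIZE, NUM_FRAGS_Z, NUM_WARP_Q>(
    k_scale_smem, block_table_now, cache_k_scale, kv_idx, kv_num_heads,
    kv_head_idx, chunk_end);
produce_kv_dynamic_scale_gmem2smem_async<SharedMemFillMode::kFillZero,
                                         BLOCK_SIZE, NUM_FRAGS_Z, NUM_WARP_Q>(
    v_scale_smem, block_table_now, cache_v_scale, kv_idx, kv_num_heads,
    kv_head_idx, chunk_end);
commit_group();   // cp.async.commit_group
wait_group<0>();  // cp.async.wait_group 0: queued copies are now visible
__syncthreads();

// Distribute the staged scales into per-lane registers for the MMA tiles.
produce_k_dynamic_scale_smem2reg<BLOCK_SIZE, NUM_FRAGS_Z, NUM_WARP_Q>(
    k_smem_scale, cache_k_reg);
produce_v_dynamic_scale_smem2reg<BLOCK_SIZE, NUM_FRAGS_Z, NUM_WARP_Q>(
    v_smem_scale, cache_v_reg);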