format code (#4720)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled

This commit is contained in:
周周周
2025-11-01 19:13:50 +08:00
committed by GitHub
parent 4ac6de9a3c
commit 6e01be28e0
3 changed files with 379 additions and 379 deletions

View File

@@ -142,7 +142,6 @@ __device__ __forceinline__ void load_q_global_smem_multi_warps(
const uint32_t tx_offset = tx / 8;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
const uint32_t base_offset = q_idx_base + fx * 16 + tx_offset;
#pragma unroll
const int j = ty;
@@ -151,8 +150,7 @@ __device__ __forceinline__ void load_q_global_smem_multi_warps(
const uint32_t h_offset = offset_now % group_size;
T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride;
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
q_smem->load_128b_async<SharedMemFillMode::kNoFill>(
q_smem_offset_w, q_ptr, n_offset < qo_upper_bound);
q_smem_offset_w =
@@ -171,7 +169,7 @@ template <uint32_t group_size,
uint32_t HEAD_DIM,
typename T>
__device__ __forceinline__ void load_q_global_smem(
T* q_ptr_base,
const T* q_ptr_base,
smem_t* q_smem,
uint32_t q_idx_base,
const uint32_t qo_upper_bound,
@@ -194,10 +192,10 @@ __device__ __forceinline__ void load_q_global_smem(
const uint32_t offset_now = base_offset + j * 4;
const uint32_t n_offset = offset_now / group_size;
const uint32_t h_offset = offset_now % group_size;
T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride;
const T* q_ptr =
q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride;
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
q_smem->load_128b_async<SharedMemFillMode::kNoFill>(
q_smem_offset_w, q_ptr, n_offset < qo_upper_bound);
q_smem_offset_w =
@@ -223,8 +221,7 @@ __device__ __forceinline__ void q_smem_inplace_multiply_sm_scale_multi_warps(
constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b<T>();
#pragma unroll
for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 1024;
++i) {
for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 1024; ++i) {
const int offset = i * 1024 + ty * 256 + tx * 8;
Load<T, vec_size>(reinterpret_cast<T*>(q_smem->base) + offset, &tmp_vec);
#pragma unroll
@@ -289,11 +286,9 @@ __device__ __forceinline__ void produce_kv_blockwise(
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; // kv_idx used to check
#pragma unroll
for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 4 / num_warps;
++i) {
for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 4 / num_warps; ++i) {
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 4;
++j) {
for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, *gptr, kv_idx < kv_len);
*smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j);
*gptr += 8 * num_elems_per_128b<T>();
@@ -332,9 +327,7 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
block_size / num_elems_per_128b<CacheT>(); // 8
constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q;
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t kv_idx =
kv_idx_base +
tx % 4 * num_elems_per_128b<CacheT>();
uint32_t kv_idx = kv_idx_base + tx % 4 * num_elems_per_128b<CacheT>();
if constexpr (NUM_WARP_Q == 4) {
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
@@ -343,8 +336,7 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
for (uint32_t i = 0; i < num_frags_y * 2 / num_warps;
++i) { // m (num_frags_y * 16 / (num_warps * 8))
#pragma unroll
for (uint32_t j = 0; j < num_frags_z / 4;
++j) {
for (uint32_t j = 0; j < num_frags_z / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, cache_v_now, true);
*smem_offset = smem.advance_offset_by_column<4, num_vecs_per_blocksize>(
*smem_offset, j);
@@ -369,8 +361,7 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
for (uint32_t i = 0; i < num_frags_y * 2 / num_warps;
++i) { // m (num_frags_y * 16 / (num_warps * 8))
#pragma unroll
for (uint32_t j = 0; j < 2 * num_frags_z / 4;
++j) {
for (uint32_t j = 0; j < 2 * num_frags_z / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, cache_v_now, true);
*smem_offset =
smem.advance_offset_by_column<4, num_vecs_per_blocksize>(
@@ -392,27 +383,28 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
}
}
template<uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
template <uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_k_dynamic_scale(
T* k_smem_scale,
T* cache_k_reg,
const int* block_table_now,
const T* cache_k_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end
) {
T* k_smem_scale,
T* cache_k_reg,
const int* block_table_now,
const T* cache_k_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const T* cache_k_scale_now = cache_k_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
if (tid < block_size) {
k_smem_scale[tid] = cache_k_scale_now[tid];
}
@@ -427,10 +419,12 @@ __device__ __forceinline__ void produce_k_dynamic_scale(
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const T* cache_k_scale_now = cache_k_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
} else {
k_smem_scale[ty * 32 + tx] = 0;
}
@@ -443,20 +437,19 @@ __device__ __forceinline__ void produce_k_dynamic_scale(
}
}
template<uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
template <uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_v_dynamic_scale(
T* v_smem_scale,
T* cache_v_reg,
const int* block_table_now,
const T* cache_v_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end
) {
T* v_smem_scale,
T* cache_v_reg,
const int* block_table_now,
const T* cache_v_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
@@ -464,7 +457,9 @@ __device__ __forceinline__ void produce_v_dynamic_scale(
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const T* cache_v_scale_now = cache_v_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
if (tid < block_size) {
v_smem_scale[tid] = cache_v_scale_now[tid];
}
@@ -481,10 +476,12 @@ __device__ __forceinline__ void produce_v_dynamic_scale(
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const T* cache_v_scale_now = cache_v_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
} else {
v_smem_scale[ty * 32 + tx] = 0;
}
@@ -560,8 +557,7 @@ __device__ __forceinline__ void produce_k_blockwise_c8(
for (uint32_t i = 0; i < 2 * num_frags_z * 4 / num_warps;
++i) { // m num_frags_z * 16 / (num_warps * 4)
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 8;
++j) {
for (uint32_t j = 0; j < num_frags_y / 8; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, cache_k_now, true);
*smem_offset = smem.advance_offset_by_column<8, num_vecs_per_head>(
*smem_offset, j);
@@ -614,8 +610,7 @@ __device__ __forceinline__ void produce_v_blockwise_c4(
#pragma unroll
for (uint32_t i = 0; i < num_frags_y / num_warps; ++i) { // m
#pragma unroll
for (uint32_t j = 0; j < num_frags_z / 4;
++j) {
for (uint32_t j = 0; j < num_frags_z / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, cache_v_now, true);
*smem_offset = smem.advance_offset_by_column<2, num_vecs_per_blocksize>(
*smem_offset, j);
@@ -671,8 +666,7 @@ __device__ __forceinline__ void produce_k_blockwise_c4(
for (uint32_t i = 0; i < num_frags_z * 2 / num_warps;
++i) { // m num_frags_z * 16 / (num_warps * 8)
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 8;
++j) {
for (uint32_t j = 0; j < num_frags_y / 8; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, cache_k_now, true);
*smem_offset = smem.advance_offset_by_column<4, num_vecs_per_head>(
*smem_offset, j);
@@ -937,7 +931,7 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
uint32_t* q_smem_offset_r,
smem_t* k_smem,
uint32_t* k_smem_offset_r,
const T *cache_k_scale,
const T* cache_k_scale,
float (*s_frag)[num_frags_z][8]) {
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t num_vecs_per_head_q = head_dim / num_elems_per_128b<T>();
@@ -973,8 +967,8 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
#pragma unroll
for (uint32_t fy = 0; fy < 2; ++fy) {
T* b_frag_dq_T = reinterpret_cast<T*>(b_frag_dq);
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
convert_c8<T, IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
convert_c8<T, IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
@@ -1036,7 +1030,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
const uint32_t chunk_end,
const uint32_t attn_mask_len,
float (*s_frag)[num_frags_z][8],
const int *mask_offset = nullptr,
const int* mask_offset = nullptr,
const int sliding_window = 0) {
const uint32_t tx = threadIdx.x;
#pragma unroll
@@ -1053,24 +1047,25 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
8 * (reg_id / 4) + reg_id % 2;
bool out_of_boundary;
if (mask_offset) {
out_of_boundary = q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] || kv_idx < mask_offset[q_idx * 2]) : true;
}
else if (sliding_window > 0)
{
bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx - (int)qo_len - sliding_window;
out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || out_of_window || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
}
else
{
out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) {
const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
out_of_boundary = q_idx < qo_len
? (kv_idx >= mask_offset[q_idx * 2 + 1] ||
kv_idx < mask_offset[q_idx * 2])
: true;
} else if (sliding_window > 0) {
bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx -
(int)qo_len -
sliding_window;
out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len ||
out_of_window || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
} else {
out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len ||
(kv_idx >= chunk_end))
: kv_idx >= chunk_end);
if (attn_mask != nullptr && kv_idx > kv_len - qo_len &&
kv_idx < chunk_end && q_idx < attn_mask_len) {
const int32_t mask_idx =
q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
bool mask = attn_mask[mask_idx];
out_of_boundary |= mask;
}
@@ -1236,7 +1231,7 @@ __device__ __forceinline__ void compute_sfm_v_c8(
float (*s_frag)[num_frags_z][8],
float (*o_frag)[num_frags_y][8],
float (*d)[2],
const T *cache_v_scale) {
const T* cache_v_scale) {
constexpr uint32_t num_vecs_per_blocksize =
block_size / num_elems_per_128b<CacheT>();
T s_frag_f16[num_frags_x][num_frags_z][8];
@@ -1268,8 +1263,8 @@ __device__ __forceinline__ void compute_sfm_v_c8(
#pragma unroll
for (uint32_t fz = 0; fz < 2; ++fz) {
T* b_frag_dq_T = reinterpret_cast<T*>(b_frag_dq);
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
convert_c8<T, IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T, IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
@@ -1300,7 +1295,6 @@ __device__ __forceinline__ void compute_sfm_v_c8(
o_frag[fx][fy],
(uint32_t*)(s_frag_f16[fx][kz * 2 + fz]),
b_frag_dq);
}
}
}
@@ -1328,7 +1322,7 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
float (*s_frag)[num_frags_z][8],
float (*o_frag)[num_frags_y][8],
float (*d)[2],
T *cache_v_scale) {
T* cache_v_scale) {
constexpr uint32_t num_vecs_per_blocksize =
block_size / num_elems_per_128b<CacheT>();
@@ -1362,8 +1356,8 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
for (uint32_t fz = 0; fz < 2; ++fz) {
// dequant b_frag -> b_frag_dq
T* b_frag_dq_T = reinterpret_cast<T*>(b_frag_dq);
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
convert_c8<T, IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T, IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
@@ -1372,7 +1366,7 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
#pragma unroll
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
@@ -1431,8 +1425,7 @@ __device__ __forceinline__ void compute_sfm_v(smem_t* v_smem,
}
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z;
++fz) {
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t b_frag[4];
@@ -1611,10 +1604,9 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps(
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t o_frag_f16[4];
vec_cast<T, float, 8>((T*)o_frag_f16, o_frag[fx][fy]);
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<
num_vecs_per_head>(
fx * 16 + tx / 4,
fy * 2);
uint32_t o_smem_offset_w =
smem_t::get_permuted_offset<num_vecs_per_head>(fx * 16 + tx / 4,
fy * 2);
((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0];
((uint32_t*)(o_smem->base + o_smem_offset_w +
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1];
@@ -1627,8 +1619,8 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps(
}
__syncthreads();
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
ty * 4 + tx / 8, tx % 8);
uint32_t o_smem_offset_w =
smem_t::get_permuted_offset<num_vecs_per_head>(ty * 4 + tx / 8, tx % 8);
o_idx_base += (tx / 8) / group_size;
o_ptr_base += ((tx / 8) / group_size) * qo_n_stride +
@@ -1642,8 +1634,7 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps(
T* o_ptr = o_ptr_base + ((fx * 16 + j * 4) / group_size) * qo_n_stride +
((fx * 16 + j * 4) % group_size) * qo_h_stride;
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
if (o_idx < qo_upper_bound) {
// need write
o_smem->store_128b(o_smem_offset_w, o_ptr);
@@ -1658,7 +1649,6 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps(
}
}
template <typename T, int VEC_SIZE, typename OutT>
struct StoreFunc {
__device__ __forceinline__ void operator()(
@@ -1717,7 +1707,6 @@ struct StoreFunc<T, VEC_SIZE, __nv_fp8_e4m3> {
}
};
template <typename T, int VEC_SIZE>
struct StoreFunc<T, VEC_SIZE, T> {
__device__ __forceinline__ void operator()(
@@ -1770,10 +1759,9 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant(
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t o_frag_f16[4];
vec_cast<T, float, 8>((T*)o_frag_f16, o_frag[fx][fy]);
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<
num_vecs_per_head>(
fx * 16 + tx / 4,
fy * 2);
uint32_t o_smem_offset_w =
smem_t::get_permuted_offset<num_vecs_per_head>(fx * 16 + tx / 4,
fy * 2);
((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0];
((uint32_t*)(o_smem->base + o_smem_offset_w +
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1];
@@ -1786,8 +1774,8 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant(
}
__syncthreads();
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
ty * 4 + tx / 8, tx % 8);
uint32_t o_smem_offset_w =
smem_t::get_permuted_offset<num_vecs_per_head>(ty * 4 + tx / 8, tx % 8);
const uint32_t tx_offset = tx / 8;
#pragma unroll
@@ -1804,8 +1792,7 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant(
uint32_t shift_smooth_offset = (q_head_idx_base + h_offset) * head_dim +
tx % 8 * num_elems_per_128b<T>();
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
if (n_offset < qo_upper_bound) {
if constexpr (!partition_kv) {
Load<T, VEC_SIZE>(
@@ -1881,10 +1868,8 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant(
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t o_frag_f16[4];
vec_cast<T, float, 8>((T*)o_frag_f16, o_frag[fx][fy]);
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<
num_vecs_per_head>(
(ty * num_frags_x + fx) * 16 + tx / 4,
fy * 2);
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
(ty * num_frags_x + fx) * 16 + tx / 4, fy * 2);
((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0];
((uint32_t*)(o_smem->base + o_smem_offset_w +
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1];
@@ -1897,8 +1882,7 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant(
__syncthreads();
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
ty * num_frags_x * 16 + tx / 8,
tx % 8);
ty * num_frags_x * 16 + tx / 8, tx % 8);
const uint32_t tx_offset = tx / 8;
#pragma unroll
@@ -1914,13 +1898,12 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant(
uint32_t shift_smooth_offset = (q_head_idx_base + h_offset) * head_dim +
tx % 8 * num_elems_per_128b<T>();
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
if (n_offset < qo_upper_bound) {
if (!partition_kv) {
Load<T, VEC_SIZE>(
reinterpret_cast<T*>(o_smem->base + o_smem_offset_w),
&ori_out_vec);
reinterpret_cast<T*>(o_smem->base + o_smem_offset_w),
&ori_out_vec);
if (in_scale > 0.0) {
if (shift_bias) {
Load<T, VEC_SIZE>(shift_bias + shift_smooth_offset,
@@ -1929,16 +1912,16 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant(
&smooth_weight_vec);
}
}
#pragma unroll
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
StoreFunc<T, VEC_SIZE, OutT>()(ori_out_vec,
shift_bias_vec,
smooth_weight_vec,
out_vec,
quant_max_bound,
quant_min_bound,
in_scale,
i);
shift_bias_vec,
smooth_weight_vec,
out_vec,
quant_max_bound,
quant_min_bound,
in_scale,
i);
}
Store<OutT, VEC_SIZE>(out_vec, o_ptr);
} else {
@@ -1979,10 +1962,8 @@ __device__ __forceinline__ void write_o_reg_gmem(
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t o_frag_f16[4];
vec_cast<T, float, 8>((T*)o_frag_f16, o_frag[fx][fy]);
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<
num_vecs_per_head>(
(ty * num_frags_x + fx) * 16 + tx / 4,
fy * 2);
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
(ty * num_frags_x + fx) * 16 + tx / 4, fy * 2);
((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0];
((uint32_t*)(o_smem->base + o_smem_offset_w +
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1];
@@ -1995,8 +1976,7 @@ __device__ __forceinline__ void write_o_reg_gmem(
__syncthreads();
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
ty * num_frags_x * 16 + tx / 8,
tx % 8);
ty * num_frags_x * 16 + tx / 8, tx % 8);
o_idx_base += (tx / 8) / group_size;
o_ptr_base += ((tx / 8) / group_size) * qo_n_stride +
@@ -2009,8 +1989,7 @@ __device__ __forceinline__ void write_o_reg_gmem(
T* o_ptr = o_ptr_base + ((fx * 16 + j * 4) / group_size) * qo_n_stride +
((fx * 16 + j * 4) % group_size) * qo_h_stride;
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
if (o_idx < qo_upper_bound) {
o_smem->store_128b(o_smem_offset_w, o_ptr);
}
@@ -2125,7 +2104,6 @@ __global__ void merge_multi_chunks_kernel(
&out[(qid * num_heads + hid) * head_dim + vid * vec_size]);
}
template <uint32_t num_frags_x, uint32_t num_frags_y, typename T>
__device__ __forceinline__ void merge_block_res(float (*o_frag)[num_frags_y][8],
float* md_smem,
@@ -2307,18 +2285,18 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true>
__global__ void merge_multi_chunks_decoder_kernel(
const T *__restrict__ multi_out, // [token_num, num_chunks, num_heads,
const T* __restrict__ multi_out, // [token_num, num_chunks, num_heads,
// head_dim]
const float *__restrict__ multi_m, // [token_num, num_chunks, num_heads]
const float *__restrict__ multi_d, // [token_num, num_chunks, num_heads]
const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ cu_seqlens_q,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T *__restrict__ sinks, // [q_num_heads]
OutT *__restrict__ out,
const float* __restrict__ multi_m, // [token_num, num_chunks, num_heads]
const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads]
const int* __restrict__ seq_lens_q,
const int* __restrict__ seq_lens_kv,
const int* __restrict__ seq_lens_encoder,
const int* __restrict__ cu_seqlens_q,
const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T* __restrict__ sinks, // [q_num_heads]
OutT* __restrict__ out,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -2419,8 +2397,14 @@ __global__ void merge_multi_chunks_decoder_kernel(
}
#pragma unroll
for (int i = 0; i < vec_size; ++i) {
StoreFunc<T, vec_size, OutT>()(
st.o, shift_bias_vec, smooth_weight_vec, out_vec, quant_max_bound, quant_min_bound, in_scale, i);
StoreFunc<T, vec_size, OutT>()(st.o,
shift_bias_vec,
smooth_weight_vec,
out_vec,
quant_max_bound,
quant_min_bound,
in_scale,
i);
}
Store<OutT, vec_size>(
out_vec,
@@ -2435,19 +2419,19 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true>
__global__ void merge_multi_chunks_v2_kernel(
const T *__restrict__ multi_out, // [token_num, num_chunks, num_heads,
const T* __restrict__ multi_out, // [token_num, num_chunks, num_heads,
// head_dim]
const float *__restrict__ multi_m, // [token_num, num_chunks, num_heads]
const float *__restrict__ multi_d, // [token_num, num_chunks, num_heads]
const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ batch_id_per_token,
const int *__restrict__ cu_seqlens_q,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T *__restrict__ sinks, // [q_num_heads]
OutT *__restrict__ out,
const float* __restrict__ multi_m, // [token_num, num_chunks, num_heads]
const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads]
const int* __restrict__ seq_lens_q,
const int* __restrict__ seq_lens_kv,
const int* __restrict__ seq_lens_encoder,
const int* __restrict__ batch_id_per_token,
const int* __restrict__ cu_seqlens_q,
const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T* __restrict__ sinks, // [q_num_heads]
OutT* __restrict__ out,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -2464,7 +2448,7 @@ __global__ void merge_multi_chunks_v2_kernel(
__shared__ float md_smem[bdy * 2];
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
const uint32_t bid = batch_id_per_token[qid];
if(bid == -1){
if (bid == -1) {
continue;
}
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
@@ -2486,7 +2470,7 @@ __global__ void merge_multi_chunks_v2_kernel(
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
if (num_chunks_this_seq <= 1) {
continue;
}else if (!ENABLE_PREFILL){
} else if (!ENABLE_PREFILL) {
continue;
}
@@ -2496,12 +2480,12 @@ __global__ void merge_multi_chunks_v2_kernel(
if constexpr (std::is_same<T, half>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((half2 *)(&res_vec) + i) = make_half2(0, 0);
*((half2*)(&res_vec) + i) = make_half2(0, 0);
}
} else {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((nv_bfloat162 *)(&res_vec) + i) = make_bfloat162(0, 0);
*((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
}
}
float m;
@@ -2581,10 +2565,17 @@ __global__ void merge_multi_chunks_v2_kernel(
Load<T, vec_size>(smooth_weight + shift_smooth_offset,
&smooth_weight_vec);
}
#pragma unroll
for (int i = 0; i < vec_size; ++i) {
StoreFunc<T, vec_size, OutT>()(
st.o, shift_bias_vec, smooth_weight_vec, out_vec, quant_max_bound, quant_min_bound, in_scale, i);
StoreFunc<T, vec_size, OutT>()(st.o,
shift_bias_vec,
smooth_weight_vec,
out_vec,
quant_max_bound,
quant_min_bound,
in_scale,
i);
}
Store<OutT, vec_size>(
out_vec, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]);