format code (#4720)
Some checks failed
CE Compile Job / ce_job_pre_check (push) Has been cancelled
Deploy GitHub Pages / deploy (push) Has been cancelled
CE Compile Job / print_ce_job_pre_check_outputs (push) Has been cancelled
CE Compile Job / FD-Clone-Linux (push) Has been cancelled
CE Compile Job / Show Code Archive Output (push) Has been cancelled
CE Compile Job / BUILD_SM8090 (push) Has been cancelled
CE Compile Job / BUILD_SM8689 (push) Has been cancelled
CE Compile Job / CE_UPLOAD (push) Has been cancelled
Publish Job / publish_pre_check (push) Has been cancelled
Publish Job / print_publish_pre_check_outputs (push) Has been cancelled
Publish Job / FD-Clone-Linux (push) Has been cancelled
Publish Job / Show Code Archive Output (push) Has been cancelled
Publish Job / BUILD_SM8090 (push) Has been cancelled
Publish Job / BUILD_SM8689 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8090 (push) Has been cancelled
Publish Job / PADDLE_PYPI_UPLOAD_8689 (push) Has been cancelled
Publish Job / Run FD Image Build (push) Has been cancelled
Publish Job / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
Publish Job / Run FastDeploy LogProb Tests (push) Has been cancelled
Publish Job / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
Publish Job / Run Base Tests (push) Has been cancelled
Publish Job / Run Accuracy Tests (push) Has been cancelled
Publish Job / Run Stable Tests (push) Has been cancelled
CI Images Build / FD-Clone-Linux (push) Has been cancelled
CI Images Build / Show Code Archive Output (push) Has been cancelled
CI Images Build / CI Images Build (push) Has been cancelled
CI Images Build / BUILD_SM8090 (push) Has been cancelled
CI Images Build / Run FastDeploy Unit Tests and Coverage (push) Has been cancelled
CI Images Build / Run FastDeploy LogProb Tests (push) Has been cancelled
CI Images Build / Extracted partial CE model tasks to run in CI. (push) Has been cancelled
CI Images Build / Run Base Tests (push) Has been cancelled
CI Images Build / Run Accuracy Tests (push) Has been cancelled
CI Images Build / Run Stable Tests (push) Has been cancelled
CI Images Build / Publish Docker Images Pre Check (push) Has been cancelled

This commit is contained in:
周周周
2025-11-01 19:13:50 +08:00
committed by GitHub
parent 4ac6de9a3c
commit 6e01be28e0
3 changed files with 379 additions and 379 deletions

View File

@@ -142,7 +142,6 @@ __device__ __forceinline__ void load_q_global_smem_multi_warps(
const uint32_t tx_offset = tx / 8;
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
const uint32_t base_offset = q_idx_base + fx * 16 + tx_offset;
#pragma unroll
const int j = ty;
@@ -151,8 +150,7 @@ __device__ __forceinline__ void load_q_global_smem_multi_warps(
const uint32_t h_offset = offset_now % group_size;
T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride;
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
q_smem->load_128b_async<SharedMemFillMode::kNoFill>(
q_smem_offset_w, q_ptr, n_offset < qo_upper_bound);
q_smem_offset_w =
@@ -171,7 +169,7 @@ template <uint32_t group_size,
uint32_t HEAD_DIM,
typename T>
__device__ __forceinline__ void load_q_global_smem(
T* q_ptr_base,
const T* q_ptr_base,
smem_t* q_smem,
uint32_t q_idx_base,
const uint32_t qo_upper_bound,
@@ -194,10 +192,10 @@ __device__ __forceinline__ void load_q_global_smem(
const uint32_t offset_now = base_offset + j * 4;
const uint32_t n_offset = offset_now / group_size;
const uint32_t h_offset = offset_now % group_size;
T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride;
const T* q_ptr =
q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride;
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
q_smem->load_128b_async<SharedMemFillMode::kNoFill>(
q_smem_offset_w, q_ptr, n_offset < qo_upper_bound);
q_smem_offset_w =
@@ -223,8 +221,7 @@ __device__ __forceinline__ void q_smem_inplace_multiply_sm_scale_multi_warps(
constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b<T>();
#pragma unroll
for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 1024;
++i) {
for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 1024; ++i) {
const int offset = i * 1024 + ty * 256 + tx * 8;
Load<T, vec_size>(reinterpret_cast<T*>(q_smem->base) + offset, &tmp_vec);
#pragma unroll
@@ -289,11 +286,9 @@ __device__ __forceinline__ void produce_kv_blockwise(
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; // kv_idx used to check
#pragma unroll
for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 4 / num_warps;
++i) {
for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 4 / num_warps; ++i) {
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 4;
++j) {
for (uint32_t j = 0; j < num_frags_y / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, *gptr, kv_idx < kv_len);
*smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j);
*gptr += 8 * num_elems_per_128b<T>();
@@ -332,9 +327,7 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
block_size / num_elems_per_128b<CacheT>(); // 8
constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q;
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
uint32_t kv_idx =
kv_idx_base +
tx % 4 * num_elems_per_128b<CacheT>();
uint32_t kv_idx = kv_idx_base + tx % 4 * num_elems_per_128b<CacheT>();
if constexpr (NUM_WARP_Q == 4) {
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
@@ -343,8 +336,7 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
for (uint32_t i = 0; i < num_frags_y * 2 / num_warps;
++i) { // m (num_frags_y * 16 / (num_warps * 8))
#pragma unroll
for (uint32_t j = 0; j < num_frags_z / 4;
++j) {
for (uint32_t j = 0; j < num_frags_z / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, cache_v_now, true);
*smem_offset = smem.advance_offset_by_column<4, num_vecs_per_blocksize>(
*smem_offset, j);
@@ -369,8 +361,7 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
for (uint32_t i = 0; i < num_frags_y * 2 / num_warps;
++i) { // m (num_frags_y * 16 / (num_warps * 8))
#pragma unroll
for (uint32_t j = 0; j < 2 * num_frags_z / 4;
++j) {
for (uint32_t j = 0; j < 2 * num_frags_z / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, cache_v_now, true);
*smem_offset =
smem.advance_offset_by_column<4, num_vecs_per_blocksize>(
@@ -392,27 +383,28 @@ __device__ __forceinline__ void produce_v_blockwise_c8(
}
}
template<uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
template <uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_k_dynamic_scale(
T* k_smem_scale,
T* cache_k_reg,
const int* block_table_now,
const T* cache_k_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end
) {
T* k_smem_scale,
T* cache_k_reg,
const int* block_table_now,
const T* cache_k_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
// 4 warps shared block_size
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const T* cache_k_scale_now = cache_k_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
if (tid < block_size) {
k_smem_scale[tid] = cache_k_scale_now[tid];
}
@@ -427,10 +419,12 @@ __device__ __forceinline__ void produce_k_dynamic_scale(
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_k_scale_now = cache_k_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const T* cache_k_scale_now = cache_k_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
k_smem_scale[ty * 32 + tx] = cache_k_scale_now[(ty % 2) * 32 + tx];
} else {
k_smem_scale[ty * 32 + tx] = 0;
}
@@ -443,20 +437,19 @@ __device__ __forceinline__ void produce_k_dynamic_scale(
}
}
template<uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
template <uint32_t block_size,
uint32_t num_frags_z,
uint32_t NUM_WARP_Q,
typename T>
__device__ __forceinline__ void produce_v_dynamic_scale(
T* v_smem_scale,
T* cache_v_reg,
const int* block_table_now,
const T* cache_v_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end
) {
T* v_smem_scale,
T* cache_v_reg,
const int* block_table_now,
const T* cache_v_scale,
const uint32_t kv_idx,
const uint32_t kv_num_heads,
const uint32_t kv_head_idx,
const uint32_t chunk_end) {
const uint32_t tx = threadIdx.x, ty = threadIdx.y;
if constexpr (NUM_WARP_Q == 4) {
@@ -464,7 +457,9 @@ __device__ __forceinline__ void produce_v_dynamic_scale(
const uint32_t tid = ty * 32 + tx;
int block_id = __ldg(&block_table_now[kv_idx / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const T* cache_v_scale_now = cache_v_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
if (tid < block_size) {
v_smem_scale[tid] = cache_v_scale_now[tid];
}
@@ -481,10 +476,12 @@ __device__ __forceinline__ void produce_v_dynamic_scale(
const uint32_t kv_idx_now = kv_idx + block_size * ty / 2;
int block_id = __ldg(&block_table_now[kv_idx_now / block_size]);
if (block_id < 0) block_id = 0;
const T* cache_v_scale_now = cache_v_scale + block_id * kv_num_heads * block_size + kv_head_idx * block_size;
const T* cache_v_scale_now = cache_v_scale +
block_id * kv_num_heads * block_size +
kv_head_idx * block_size;
const int kv_idx_this_thread = kv_idx + ty * 32 + tx;
if (kv_idx_this_thread < chunk_end) {
v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
v_smem_scale[ty * 32 + tx] = cache_v_scale_now[(ty % 2) * 32 + tx];
} else {
v_smem_scale[ty * 32 + tx] = 0;
}
@@ -560,8 +557,7 @@ __device__ __forceinline__ void produce_k_blockwise_c8(
for (uint32_t i = 0; i < 2 * num_frags_z * 4 / num_warps;
++i) { // m num_frags_z * 16 / (num_warps * 4)
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 8;
++j) {
for (uint32_t j = 0; j < num_frags_y / 8; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, cache_k_now, true);
*smem_offset = smem.advance_offset_by_column<8, num_vecs_per_head>(
*smem_offset, j);
@@ -614,8 +610,7 @@ __device__ __forceinline__ void produce_v_blockwise_c4(
#pragma unroll
for (uint32_t i = 0; i < num_frags_y / num_warps; ++i) { // m
#pragma unroll
for (uint32_t j = 0; j < num_frags_z / 4;
++j) {
for (uint32_t j = 0; j < num_frags_z / 4; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, cache_v_now, true);
*smem_offset = smem.advance_offset_by_column<2, num_vecs_per_blocksize>(
*smem_offset, j);
@@ -671,8 +666,7 @@ __device__ __forceinline__ void produce_k_blockwise_c4(
for (uint32_t i = 0; i < num_frags_z * 2 / num_warps;
++i) { // m num_frags_z * 16 / (num_warps * 8)
#pragma unroll
for (uint32_t j = 0; j < num_frags_y / 8;
++j) {
for (uint32_t j = 0; j < num_frags_y / 8; ++j) {
smem.load_128b_async<fill_mode>(*smem_offset, cache_k_now, true);
*smem_offset = smem.advance_offset_by_column<4, num_vecs_per_head>(
*smem_offset, j);
@@ -937,7 +931,7 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
uint32_t* q_smem_offset_r,
smem_t* k_smem,
uint32_t* k_smem_offset_r,
const T *cache_k_scale,
const T* cache_k_scale,
float (*s_frag)[num_frags_z][8]) {
constexpr uint32_t head_dim = num_frags_y * 16;
constexpr uint32_t num_vecs_per_head_q = head_dim / num_elems_per_128b<T>();
@@ -973,8 +967,8 @@ __device__ __forceinline__ void compute_qk_c8(smem_t* q_smem,
#pragma unroll
for (uint32_t fy = 0; fy < 2; ++fy) {
T* b_frag_dq_T = reinterpret_cast<T*>(b_frag_dq);
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
convert_c8<T, IsFP8>(b_frag_dq_T, b_frag[fy * 2]);
convert_c8<T, IsFP8>(b_frag_dq_T + 4, b_frag[fy * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
@@ -1036,7 +1030,7 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
const uint32_t chunk_end,
const uint32_t attn_mask_len,
float (*s_frag)[num_frags_z][8],
const int *mask_offset = nullptr,
const int* mask_offset = nullptr,
const int sliding_window = 0) {
const uint32_t tx = threadIdx.x;
#pragma unroll
@@ -1053,24 +1047,25 @@ __device__ __forceinline__ void mask_s(const bool* attn_mask,
8 * (reg_id / 4) + reg_id % 2;
bool out_of_boundary;
if (mask_offset) {
out_of_boundary = q_idx < qo_len ? (kv_idx >= mask_offset[q_idx * 2 + 1] || kv_idx < mask_offset[q_idx * 2]) : true;
}
else if (sliding_window > 0)
{
bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx - (int)qo_len - sliding_window;
out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || out_of_window || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
}
else
{
out_of_boundary =
(causal
? (kv_idx > kv_len + q_idx - qo_len || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
if (attn_mask != nullptr && kv_idx > kv_len - qo_len && kv_idx < chunk_end && q_idx < attn_mask_len) {
const int32_t mask_idx = q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
out_of_boundary = q_idx < qo_len
? (kv_idx >= mask_offset[q_idx * 2 + 1] ||
kv_idx < mask_offset[q_idx * 2])
: true;
} else if (sliding_window > 0) {
bool out_of_window = int(kv_idx) <= (int)kv_len + (int)q_idx -
(int)qo_len -
sliding_window;
out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len ||
out_of_window || (kv_idx >= chunk_end))
: kv_idx >= chunk_end);
} else {
out_of_boundary = (causal ? (kv_idx > kv_len + q_idx - qo_len ||
(kv_idx >= chunk_end))
: kv_idx >= chunk_end);
if (attn_mask != nullptr && kv_idx > kv_len - qo_len &&
kv_idx < chunk_end && q_idx < attn_mask_len) {
const int32_t mask_idx =
q_idx * attn_mask_len + kv_idx - kv_len + qo_len;
bool mask = attn_mask[mask_idx];
out_of_boundary |= mask;
}
@@ -1236,7 +1231,7 @@ __device__ __forceinline__ void compute_sfm_v_c8(
float (*s_frag)[num_frags_z][8],
float (*o_frag)[num_frags_y][8],
float (*d)[2],
const T *cache_v_scale) {
const T* cache_v_scale) {
constexpr uint32_t num_vecs_per_blocksize =
block_size / num_elems_per_128b<CacheT>();
T s_frag_f16[num_frags_x][num_frags_z][8];
@@ -1268,8 +1263,8 @@ __device__ __forceinline__ void compute_sfm_v_c8(
#pragma unroll
for (uint32_t fz = 0; fz < 2; ++fz) {
T* b_frag_dq_T = reinterpret_cast<T*>(b_frag_dq);
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
convert_c8<T, IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T, IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
@@ -1300,7 +1295,6 @@ __device__ __forceinline__ void compute_sfm_v_c8(
o_frag[fx][fy],
(uint32_t*)(s_frag_f16[fx][kz * 2 + fz]),
b_frag_dq);
}
}
}
@@ -1328,7 +1322,7 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
float (*s_frag)[num_frags_z][8],
float (*o_frag)[num_frags_y][8],
float (*d)[2],
T *cache_v_scale) {
T* cache_v_scale) {
constexpr uint32_t num_vecs_per_blocksize =
block_size / num_elems_per_128b<CacheT>();
@@ -1362,8 +1356,8 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
for (uint32_t fz = 0; fz < 2; ++fz) {
// dequant b_frag -> b_frag_dq
T* b_frag_dq_T = reinterpret_cast<T*>(b_frag_dq);
convert_c8<T,IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T,IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
convert_c8<T, IsFP8>(b_frag_dq_T, b_frag[fz * 2]);
convert_c8<T, IsFP8>(b_frag_dq_T + 4, b_frag[fz * 2 + 1]);
// scale zp
if constexpr (!IsDynamicC8) {
if constexpr (is_scale_channel_wise) {
@@ -1372,7 +1366,7 @@ __device__ __forceinline__ void compute_sfm_v_c8_iter_sq_bvec(
b_frag_dq_T[b_i] *= cache_v_scale[b_i / 4 + fy * 2];
}
} else {
#pragma unroll
#pragma unroll
for (uint32_t b_i = 0; b_i < 8; ++b_i) {
b_frag_dq_T[b_i] *= cache_v_scale[0];
}
@@ -1431,8 +1425,7 @@ __device__ __forceinline__ void compute_sfm_v(smem_t* v_smem,
}
#pragma unroll
for (uint32_t fz = 0; fz < num_frags_z;
++fz) {
for (uint32_t fz = 0; fz < num_frags_z; ++fz) {
#pragma unroll
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t b_frag[4];
@@ -1611,10 +1604,9 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps(
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t o_frag_f16[4];
vec_cast<T, float, 8>((T*)o_frag_f16, o_frag[fx][fy]);
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<
num_vecs_per_head>(
fx * 16 + tx / 4,
fy * 2);
uint32_t o_smem_offset_w =
smem_t::get_permuted_offset<num_vecs_per_head>(fx * 16 + tx / 4,
fy * 2);
((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0];
((uint32_t*)(o_smem->base + o_smem_offset_w +
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1];
@@ -1627,8 +1619,8 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps(
}
__syncthreads();
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
ty * 4 + tx / 8, tx % 8);
uint32_t o_smem_offset_w =
smem_t::get_permuted_offset<num_vecs_per_head>(ty * 4 + tx / 8, tx % 8);
o_idx_base += (tx / 8) / group_size;
o_ptr_base += ((tx / 8) / group_size) * qo_n_stride +
@@ -1642,8 +1634,7 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps(
T* o_ptr = o_ptr_base + ((fx * 16 + j * 4) / group_size) * qo_n_stride +
((fx * 16 + j * 4) % group_size) * qo_h_stride;
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
if (o_idx < qo_upper_bound) {
// need write
o_smem->store_128b(o_smem_offset_w, o_ptr);
@@ -1658,7 +1649,6 @@ __device__ __forceinline__ void write_o_reg_gmem_kv_multi_warps(
}
}
template <typename T, int VEC_SIZE, typename OutT>
struct StoreFunc {
__device__ __forceinline__ void operator()(
@@ -1717,7 +1707,6 @@ struct StoreFunc<T, VEC_SIZE, __nv_fp8_e4m3> {
}
};
template <typename T, int VEC_SIZE>
struct StoreFunc<T, VEC_SIZE, T> {
__device__ __forceinline__ void operator()(
@@ -1770,10 +1759,9 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant(
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t o_frag_f16[4];
vec_cast<T, float, 8>((T*)o_frag_f16, o_frag[fx][fy]);
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<
num_vecs_per_head>(
fx * 16 + tx / 4,
fy * 2);
uint32_t o_smem_offset_w =
smem_t::get_permuted_offset<num_vecs_per_head>(fx * 16 + tx / 4,
fy * 2);
((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0];
((uint32_t*)(o_smem->base + o_smem_offset_w +
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1];
@@ -1786,8 +1774,8 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant(
}
__syncthreads();
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
ty * 4 + tx / 8, tx % 8);
uint32_t o_smem_offset_w =
smem_t::get_permuted_offset<num_vecs_per_head>(ty * 4 + tx / 8, tx % 8);
const uint32_t tx_offset = tx / 8;
#pragma unroll
@@ -1804,8 +1792,7 @@ __device__ __forceinline__ void write_o_reg_gmem_multi_warps_shift_smooth_quant(
uint32_t shift_smooth_offset = (q_head_idx_base + h_offset) * head_dim +
tx % 8 * num_elems_per_128b<T>();
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
if (n_offset < qo_upper_bound) {
if constexpr (!partition_kv) {
Load<T, VEC_SIZE>(
@@ -1881,10 +1868,8 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant(
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t o_frag_f16[4];
vec_cast<T, float, 8>((T*)o_frag_f16, o_frag[fx][fy]);
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<
num_vecs_per_head>(
(ty * num_frags_x + fx) * 16 + tx / 4,
fy * 2);
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
(ty * num_frags_x + fx) * 16 + tx / 4, fy * 2);
((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0];
((uint32_t*)(o_smem->base + o_smem_offset_w +
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1];
@@ -1897,8 +1882,7 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant(
__syncthreads();
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
ty * num_frags_x * 16 + tx / 8,
tx % 8);
ty * num_frags_x * 16 + tx / 8, tx % 8);
const uint32_t tx_offset = tx / 8;
#pragma unroll
@@ -1914,13 +1898,12 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant(
uint32_t shift_smooth_offset = (q_head_idx_base + h_offset) * head_dim +
tx % 8 * num_elems_per_128b<T>();
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
if (n_offset < qo_upper_bound) {
if (!partition_kv) {
Load<T, VEC_SIZE>(
reinterpret_cast<T*>(o_smem->base + o_smem_offset_w),
&ori_out_vec);
reinterpret_cast<T*>(o_smem->base + o_smem_offset_w),
&ori_out_vec);
if (in_scale > 0.0) {
if (shift_bias) {
Load<T, VEC_SIZE>(shift_bias + shift_smooth_offset,
@@ -1929,16 +1912,16 @@ __device__ __forceinline__ void write_o_reg_gmem_shift_smooth_quant(
&smooth_weight_vec);
}
}
#pragma unroll
#pragma unroll
for (int i = 0; i < VEC_SIZE; ++i) {
StoreFunc<T, VEC_SIZE, OutT>()(ori_out_vec,
shift_bias_vec,
smooth_weight_vec,
out_vec,
quant_max_bound,
quant_min_bound,
in_scale,
i);
shift_bias_vec,
smooth_weight_vec,
out_vec,
quant_max_bound,
quant_min_bound,
in_scale,
i);
}
Store<OutT, VEC_SIZE>(out_vec, o_ptr);
} else {
@@ -1979,10 +1962,8 @@ __device__ __forceinline__ void write_o_reg_gmem(
for (uint32_t fy = 0; fy < num_frags_y; ++fy) {
uint32_t o_frag_f16[4];
vec_cast<T, float, 8>((T*)o_frag_f16, o_frag[fx][fy]);
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<
num_vecs_per_head>(
(ty * num_frags_x + fx) * 16 + tx / 4,
fy * 2);
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
(ty * num_frags_x + fx) * 16 + tx / 4, fy * 2);
((uint32_t*)(o_smem->base + o_smem_offset_w))[tx % 4] = o_frag_f16[0];
((uint32_t*)(o_smem->base + o_smem_offset_w +
8 * num_vecs_per_head))[tx % 4] = o_frag_f16[1];
@@ -1995,8 +1976,7 @@ __device__ __forceinline__ void write_o_reg_gmem(
__syncthreads();
uint32_t o_smem_offset_w = smem_t::get_permuted_offset<num_vecs_per_head>(
ty * num_frags_x * 16 + tx / 8,
tx % 8);
ty * num_frags_x * 16 + tx / 8, tx % 8);
o_idx_base += (tx / 8) / group_size;
o_ptr_base += ((tx / 8) / group_size) * qo_n_stride +
@@ -2009,8 +1989,7 @@ __device__ __forceinline__ void write_o_reg_gmem(
T* o_ptr = o_ptr_base + ((fx * 16 + j * 4) / group_size) * qo_n_stride +
((fx * 16 + j * 4) % group_size) * qo_h_stride;
#pragma unroll
for (uint32_t fyo = 0; fyo < num_frags_y / 4;
++fyo) {
for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) {
if (o_idx < qo_upper_bound) {
o_smem->store_128b(o_smem_offset_w, o_ptr);
}
@@ -2125,7 +2104,6 @@ __global__ void merge_multi_chunks_kernel(
&out[(qid * num_heads + hid) * head_dim + vid * vec_size]);
}
template <uint32_t num_frags_x, uint32_t num_frags_y, typename T>
__device__ __forceinline__ void merge_block_res(float (*o_frag)[num_frags_y][8],
float* md_smem,
@@ -2307,18 +2285,18 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true>
__global__ void merge_multi_chunks_decoder_kernel(
const T *__restrict__ multi_out, // [token_num, num_chunks, num_heads,
const T* __restrict__ multi_out, // [token_num, num_chunks, num_heads,
// head_dim]
const float *__restrict__ multi_m, // [token_num, num_chunks, num_heads]
const float *__restrict__ multi_d, // [token_num, num_chunks, num_heads]
const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ cu_seqlens_q,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T *__restrict__ sinks, // [q_num_heads]
OutT *__restrict__ out,
const float* __restrict__ multi_m, // [token_num, num_chunks, num_heads]
const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads]
const int* __restrict__ seq_lens_q,
const int* __restrict__ seq_lens_kv,
const int* __restrict__ seq_lens_encoder,
const int* __restrict__ cu_seqlens_q,
const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T* __restrict__ sinks, // [q_num_heads]
OutT* __restrict__ out,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -2419,8 +2397,14 @@ __global__ void merge_multi_chunks_decoder_kernel(
}
#pragma unroll
for (int i = 0; i < vec_size; ++i) {
StoreFunc<T, vec_size, OutT>()(
st.o, shift_bias_vec, smooth_weight_vec, out_vec, quant_max_bound, quant_min_bound, in_scale, i);
StoreFunc<T, vec_size, OutT>()(st.o,
shift_bias_vec,
smooth_weight_vec,
out_vec,
quant_max_bound,
quant_min_bound,
in_scale,
i);
}
Store<OutT, vec_size>(
out_vec,
@@ -2435,19 +2419,19 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true>
__global__ void merge_multi_chunks_v2_kernel(
const T *__restrict__ multi_out, // [token_num, num_chunks, num_heads,
const T* __restrict__ multi_out, // [token_num, num_chunks, num_heads,
// head_dim]
const float *__restrict__ multi_m, // [token_num, num_chunks, num_heads]
const float *__restrict__ multi_d, // [token_num, num_chunks, num_heads]
const int *__restrict__ seq_lens_q,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ seq_lens_encoder,
const int *__restrict__ batch_id_per_token,
const int *__restrict__ cu_seqlens_q,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T *__restrict__ sinks, // [q_num_heads]
OutT *__restrict__ out,
const float* __restrict__ multi_m, // [token_num, num_chunks, num_heads]
const float* __restrict__ multi_d, // [token_num, num_chunks, num_heads]
const int* __restrict__ seq_lens_q,
const int* __restrict__ seq_lens_kv,
const int* __restrict__ seq_lens_encoder,
const int* __restrict__ batch_id_per_token,
const int* __restrict__ cu_seqlens_q,
const T* __restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T* __restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T* __restrict__ sinks, // [q_num_heads]
OutT* __restrict__ out,
const float quant_max_bound,
const float quant_min_bound,
const float in_scale,
@@ -2464,7 +2448,7 @@ __global__ void merge_multi_chunks_v2_kernel(
__shared__ float md_smem[bdy * 2];
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
const uint32_t bid = batch_id_per_token[qid];
if(bid == -1){
if (bid == -1) {
continue;
}
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
@@ -2486,7 +2470,7 @@ __global__ void merge_multi_chunks_v2_kernel(
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
if (num_chunks_this_seq <= 1) {
continue;
}else if (!ENABLE_PREFILL){
} else if (!ENABLE_PREFILL) {
continue;
}
@@ -2496,12 +2480,12 @@ __global__ void merge_multi_chunks_v2_kernel(
if constexpr (std::is_same<T, half>::value) {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((half2 *)(&res_vec) + i) = make_half2(0, 0);
*((half2*)(&res_vec) + i) = make_half2(0, 0);
}
} else {
#pragma unroll
for (int i = 0; i < vec_size / 2; ++i) {
*((nv_bfloat162 *)(&res_vec) + i) = make_bfloat162(0, 0);
*((nv_bfloat162*)(&res_vec) + i) = make_bfloat162(0, 0);
}
}
float m;
@@ -2581,10 +2565,17 @@ __global__ void merge_multi_chunks_v2_kernel(
Load<T, vec_size>(smooth_weight + shift_smooth_offset,
&smooth_weight_vec);
}
#pragma unroll
for (int i = 0; i < vec_size; ++i) {
StoreFunc<T, vec_size, OutT>()(
st.o, shift_bias_vec, smooth_weight_vec, out_vec, quant_max_bound, quant_min_bound, in_scale, i);
StoreFunc<T, vec_size, OutT>()(st.o,
shift_bias_vec,
smooth_weight_vec,
out_vec,
quant_max_bound,
quant_min_bound,
in_scale,
i);
}
Store<OutT, vec_size>(
out_vec, &out[(qid * num_heads + hid) * head_dim + vid * vec_size]);

View File

@@ -20,12 +20,11 @@
#include "utils.cuh"
template <int THREADBLOCK_SIZE>
__global__ void
GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time,
const int *seq_lens_encoder,
const int *seq_lens_this_time_merged,
const int *seq_lens_encoder_merged, const int *seq_mapping,
const int *system_lens, int *max_lens, const int batch_size) {
__global__ void GetMaxLenKernel(const int *seq_lens_decoder,
const int *seq_lens_this_time,
const int *seq_lens_encoder,
int *max_lens,
const int batch_size) {
const int tid = threadIdx.x;
typedef cub::BlockReduce<int, THREADBLOCK_SIZE> BlockReduce;
@@ -36,9 +35,6 @@ GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time,
int max_len_decoder_this_thread = 0;
int max_len_this_thread = 0;
int max_just_dec_len_this_thread = 0;
int max_just_dec_merged_len_this_time_this_thread = 0;
int max_system_len_this_thread = 0;
int max_dec_len_without_system_this_thread = 0;
int max_len_kv_this_thread = 0;
for (int i = tid; i < batch_size; i += blockDim.x) {
const int seq_len_this_time = seq_lens_this_time[i];
@@ -47,17 +43,17 @@ GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time,
max(seq_len_this_time, max_len_this_time_this_thread);
max_len_encoder_this_thread =
max(seq_lens_encoder[i], max_len_encoder_this_thread);
max_len_decoder_this_thread = max(seq_len_decoder, max_len_decoder_this_thread);
if (seq_len_this_time <= 0)
continue;
const int max_just_dec_len_now = seq_lens_encoder[i] > 0 ? 0 : seq_len_decoder;
max_len_decoder_this_thread =
max(seq_len_decoder, max_len_decoder_this_thread);
if (seq_len_this_time <= 0) continue;
const int max_just_dec_len_now =
seq_lens_encoder[i] > 0 ? 0 : seq_len_decoder;
max_len_this_thread =
max(seq_len_decoder + seq_len_this_time, max_len_this_thread);
max_just_dec_len_this_thread =
max(max_just_dec_len_this_thread, max_just_dec_len_now);
if (seq_len_decoder == 0)
continue;
if (seq_len_decoder == 0) continue;
max_len_kv_this_thread =
max(seq_len_this_time + seq_len_decoder, max_len_kv_this_thread);
}
@@ -74,14 +70,6 @@ GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time,
BlockReduce(temp_storage).Reduce(max_len_this_thread, MaxOp<int>());
int total_just_dec = BlockReduce(temp_storage)
.Reduce(max_just_dec_len_this_thread, MaxOp<int>());
int total_just_dec_merged =
BlockReduce(temp_storage)
.Reduce(max_just_dec_merged_len_this_time_this_thread, MaxOp<int>());
int total_system_len = BlockReduce(temp_storage)
.Reduce(max_system_len_this_thread, MaxOp<int>());
int total_dec_len_without_system =
BlockReduce(temp_storage)
.Reduce(max_dec_len_without_system_this_thread, MaxOp<int>());
int total_max_len_kv =
BlockReduce(temp_storage).Reduce(max_len_kv_this_thread, MaxOp<int>());
if (tid == 0) {
@@ -90,24 +78,10 @@ GetMaxLenKernel(const int *seq_lens_decoder, const int *seq_lens_this_time,
max_lens[2] = total_max_len_decoder;
max_lens[3] = total;
max_lens[4] = total_just_dec;
max_lens[5] = total_just_dec_merged;
max_lens[6] = total_system_len;
max_lens[7] = total_dec_len_without_system;
max_lens[8] = total_max_len_kv;
}
}
void GetMaxLen(const paddle::Tensor &seq_lens_tensor,
const paddle::Tensor &seq_lens_this_time,
const paddle::Tensor &seq_lens_encoder,
paddle::Tensor &max_len_tensor, const int batch_size) {
constexpr int blockSize = 1024;
GetMaxLenKernel<blockSize><<<1, blockSize, 0, seq_lens_encoder.stream()>>>(
seq_lens_tensor.data<int>(), seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(), nullptr, nullptr, nullptr, nullptr,
max_len_tensor.data<int>(), batch_size);
}
template <uint32_t config_size>
__global__ void search_chunk_size_for_mla(
const int *__restrict__ seq_lens_q,
@@ -154,11 +128,11 @@ __global__ void search_chunk_size_for_mla(
uint32_t res_id = 0;
uint32_t max_last_wave_block = 0;
for (uint32_t i = 1; i < config_size; i++) {
uint32_t last_wave_block = gridx_shared[i] % sm_cout;
if (last_wave_block >= max_last_wave_block) {
res_id = i;
max_last_wave_block = last_wave_block;
}
uint32_t last_wave_block = gridx_shared[i] % sm_cout;
if (last_wave_block >= max_last_wave_block) {
res_id = i;
max_last_wave_block = last_wave_block;
}
}
*num_blocks_x = gridx_shared[res_id];
*res_chunk_size = block_size << res_id;
@@ -185,11 +159,11 @@ __global__ void split_block_for_mla(const int *__restrict__ seq_lens_q,
int loop_times;
loop_times = cute::ceil_div(seq_len_decoder, chunk_size);
if (seq_len_encoder > 0) {
loop_times = 0;
loop_times = 0;
}
for (uint32_t tile_id = 0; tile_id < loop_times; tile_id++) {
batch_ids[index] = bid;
tile_ids_per_batch[index++] = tile_id;
batch_ids[index] = bid;
tile_ids_per_batch[index++] = tile_id;
}
}
}
@@ -255,8 +229,10 @@ __global__ void split_kv_block(const int *__restrict__ seq_lens_decoder,
const int *__restrict__ seq_lens_encoder,
int *__restrict__ batch_ids,
int *__restrict__ tile_ids_per_batch,
int *__restrict__ num_blocks_x, const int bsz,
const int pad_len, const int num_row_per_block) {
int *__restrict__ num_blocks_x,
const int bsz,
const int pad_len,
const int num_row_per_block) {
if (threadIdx.x == 0) {
int gridx = 0;
int index = 0;
@@ -281,31 +257,39 @@ void GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
paddle::Tensor &decoder_batch_ids, // Inplace
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory
paddle::Tensor &decoder_num_blocks_device, // Inplace
paddle::Tensor &decoder_chunk_size_device, // Inplace
paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU
paddle::Tensor &encoder_batch_ids, // Inplace
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, CPU
paddle::Tensor &kv_batch_ids, // Inplace
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU
paddle::Tensor &decoder_batch_ids, // Inplace
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_cpu, // Inplace, Pinned Memory
paddle::Tensor &decoder_num_blocks_device, // Inplace
paddle::Tensor &decoder_chunk_size_device, // Inplace
paddle::Tensor &max_len_tensor_cpu, // Inplace, CPU
paddle::Tensor &encoder_batch_ids, // Inplace
paddle::Tensor &encoder_tile_ids_per_batch, // Inplace
paddle::Tensor &encoder_num_blocks_x_cpu, // Inplace, CPU
paddle::Tensor &kv_batch_ids, // Inplace
paddle::Tensor &kv_tile_ids_per_batch, // Inplace
paddle::Tensor &kv_num_blocks_x_cpu, // Inplace, CPU
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const int decoder_step_token_num)
{
const int decoder_step_token_num) {
auto stream = seq_lens_encoder.stream();
int bsz = seq_lens_this_time.shape()[0];
paddle::Tensor max_len_tensor_gpu = GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, paddle::DataType::INT32, seq_lens_this_time.place());
GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
max_len_tensor_gpu, bsz);
max_len_tensor_cpu.copy_(max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
paddle::Tensor max_len_tensor_gpu =
GetEmptyTensor({max_len_tensor_cpu.shape()[0]},
paddle::DataType::INT32,
seq_lens_this_time.place());
GetMaxLenKernel<1024><<<1, 1024, 0, stream>>>(seq_lens_decoder.data<int>(),
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
max_len_tensor_gpu.data<int>(),
bsz);
max_len_tensor_cpu.copy_(
max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
auto max_len_cpu_ptr = max_len_tensor_cpu.data<int>();
int max_len_this_time = max_len_cpu_ptr[0];
@@ -320,7 +304,6 @@ void GetBlockShapeAndSplitKVBlock(
// decoder
if (max_dec_len_this_time > 0) {
const bool mla_backend = checkAttentionBackend();
if (mla_backend && group_size <= 64) {
const int set_chunk_size = get_mla_dec_chunk_size(bsz);
@@ -356,8 +339,9 @@ void GetBlockShapeAndSplitKVBlock(
const int chunk_size = decoder_chunk_size_cpu.data<int>()[0];
// NOTE: (changwenbin) When using auto_chunk,
// decode_max_tile_size must take into account the maximum case, where * 1024 can cover 128K.
// const uint32_t decoder_batch_shape = seq_lens_decoder.dims()[0] * 1024;
// decode_max_tile_size must take into account the maximum case, where *
// 1024 can cover 128K. const uint32_t decoder_batch_shape =
// seq_lens_decoder.dims()[0] * 1024;
const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
@@ -375,7 +359,6 @@ void GetBlockShapeAndSplitKVBlock(
decoder_batch_shape * sizeof(int32_t),
stream));
split_block_for_mla<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
@@ -419,49 +402,72 @@ void GetBlockShapeAndSplitKVBlock(
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
}
} else {
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_chunk_size_device.data<int>(), 64, sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
decoder_num_blocks_device.data<int>(), 0, sizeof(int32_t), stream));
decoder_num_blocks_cpu.copy_(
decoder_num_blocks_device, decoder_num_blocks_cpu.place(), false);
}
// encoder
if (max_enc_len_this_time > 0) {
const uint32_t max_tile_size_per_bs_kv = div_up(max_enc_dec_len_this_time, block_size);
const uint32_t max_tile_size_per_bs_kv =
div_up(max_enc_dec_len_this_time, block_size);
const uint32_t kv_batch_shape = bsz * max_tile_size_per_bs_kv;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_batch_ids.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(kv_tile_ids_per_batch.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(
kv_batch_ids.data<int>(), 0, kv_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(kv_tile_ids_per_batch.data<int>(),
0,
kv_batch_shape * sizeof(int32_t),
stream));
auto kv_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_kv_block<<<1, 32, 0, seq_lens_encoder.stream()>>>(
seq_lens_decoder.data<int>(),
// sequence_lengths->data<int>(),
seq_lens_encoder.data<int>(), kv_batch_ids.data<int>(),
kv_tile_ids_per_batch.data<int>(), kv_num_blocks_x.data<int>(), bsz,
block_size, block_size);
seq_lens_encoder.data<int>(),
kv_batch_ids.data<int>(),
kv_tile_ids_per_batch.data<int>(),
kv_num_blocks_x.data<int>(),
bsz,
block_size,
block_size);
kv_num_blocks_x_cpu.copy_(kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false);
kv_num_blocks_x_cpu.copy_(
kv_num_blocks_x, kv_num_blocks_x_cpu.place(), false);
// Clear buffer
const uint32_t encoder_max_tile_size_per_bs_q = div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
const uint32_t encoder_max_tile_size_per_bs_q =
div_up((max_enc_dec_len_this_time * group_size), encoder_block_shape_q);
const uint32_t encoder_batch_shape = bsz * encoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_batch_ids.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(encoder_tile_ids_per_batch.data<int>(), 0, encoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(encoder_batch_ids.data<int>(),
0,
encoder_batch_shape * sizeof(int32_t),
stream));
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(encoder_tile_ids_per_batch.data<int>(),
0,
encoder_batch_shape * sizeof(int32_t),
stream));
auto encoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(), nullptr,
split_q_block<<<1, 32, 0, stream>>>(seq_lens_encoder.data<int>(),
nullptr,
encoder_batch_ids.data<int>(),
encoder_tile_ids_per_batch.data<int>(),
encoder_num_blocks_x.data<int>(), bsz,
encoder_block_shape_q, group_size);
encoder_num_blocks_x_cpu.copy_(encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
encoder_num_blocks_x.data<int>(),
bsz,
encoder_block_shape_q,
group_size);
encoder_num_blocks_x_cpu.copy_(
encoder_num_blocks_x, encoder_num_blocks_x_cpu.place(), false);
}
}
std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
@@ -472,8 +478,7 @@ std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const int decoder_step_token_num
) {
const int decoder_step_token_num) {
return {};
}
@@ -485,39 +490,36 @@ std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const int decoder_step_token_num
) {
const int decoder_step_token_num) {
return {};
}
PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
.Inputs({
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks_cpu",
"decoder_num_blocks_device",
"decoder_chunk_size_device",
"max_len_tensor_cpu",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",
"encoder_num_blocks_x_cpu",
"kv_batch_ids",
"kv_tile_ids_per_batch",
"kv_num_blocks_x_cpu",
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks_cpu",
"decoder_num_blocks_device",
"decoder_chunk_size_device",
"max_len_tensor_cpu",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",
"encoder_num_blocks_x_cpu",
"kv_batch_ids",
"kv_tile_ids_per_batch",
"kv_num_blocks_x_cpu",
})
.Outputs({
})
.Attrs({
"encoder_block_shape_q: int",
"decoder_block_shape_q: int",
"group_size: int",
"block_size: int",
"decoder_step_token_num: int"
})
.Attrs({"encoder_block_shape_q: int",
"decoder_block_shape_q: int",
"group_size: int",
"block_size: int",
"decoder_step_token_num: int"})
.SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock))
.SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype));

View File

@@ -21,7 +21,6 @@ template <typename T,
bool CAUSAL,
uint32_t NUM_WARPS,
uint32_t NUM_WARP_Q,
uint32_t NUM_WARP_KV,
uint32_t HEAD_DIM,
uint32_t BLOCK_SIZE,
uint32_t num_frags_x,
@@ -30,13 +29,14 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true>
__global__ void multi_query_append_attention_kernel(
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
T *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
const T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) *
// head_dim]
T *__restrict__ cache_v,
const T *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
const T *__restrict__ cache_v,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T *__restrict__ sinks, // [q_num_heads]
const T *__restrict__ sinks, // [q_num_heads]
const int *__restrict__ seq_lens,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
@@ -45,7 +45,6 @@ __global__ void multi_query_append_attention_kernel(
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
const float scale,
const float quant_max_bound,
@@ -72,12 +71,10 @@ __global__ void multi_query_append_attention_kernel(
const uint32_t batch_id = batch_ids[btid];
const uint32_t tile_id = tile_ids_per_batch[btid];
const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16;
const int *block_table_now = nullptr;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
// When cudagraph capture prefill, may launch more gridDim.x
if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
return;
}
@@ -85,8 +82,7 @@ __global__ void multi_query_append_attention_kernel(
if (q_len <= 0) {
return;
}
const uint32_t q_end =
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));
uint32_t kv_len = seq_lens_kv[batch_id];
if (ENABLE_PREFILL) {
kv_len += q_len;
@@ -111,6 +107,9 @@ __global__ void multi_query_append_attention_kernel(
const uint32_t chunk_len = chunk_end - chunk_start;
extern __shared__ uint8_t smem[];
static_assert(num_frags_y * 16 == HEAD_DIM);
static_assert(num_frags_z * 16 == BLOCK_SIZE);
float s_frag[num_frags_x][num_frags_z][8];
float o_frag[num_frags_x][num_frags_y][8];
float m_frag[num_frags_x][2];
@@ -131,7 +130,7 @@ __global__ void multi_query_append_attention_kernel(
const uint32_t o_offset = q_start_seq_id * q_n_stride +
q_head_idx * HEAD_DIM +
tid % 8 * num_elems_per_128b<T>();
T *q_base_ptr = q + q_offset;
const T *q_base_ptr = q + q_offset;
T *o_base_ptr_T = nullptr;
OutT *o_base_ptr_int8 = nullptr;
if constexpr (partition_kv) {
@@ -149,11 +148,16 @@ __global__ void multi_query_append_attention_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
const int *mask_offset_this_seq =
mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16
const uint32_t q_end =
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));
load_q_global_smem<GROUP_SIZE, num_frags_x, num_frags_y, HEAD_DIM, T>(
q_base_ptr,
&qo_smem,
@@ -172,7 +176,6 @@ __global__ void multi_query_append_attention_kernel(
v_smem(smem + (NUM_WARPS * num_frags_x + num_frags_z) * 16 * HEAD_DIM *
sizeof(T));
const uint32_t num_iterations = div_up(
CAUSAL
? (min(chunk_len,
@@ -183,12 +186,13 @@ __global__ void multi_query_append_attention_kernel(
: chunk_len,
num_frags_z * 16);
const uint32_t mask_check_iteration =
(CAUSAL ? (min(chunk_len,
(CAUSAL ? (min(chunk_len,
sub_if_greater_or_zero(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: mask_offset ? 0
: chunk_len) /
(num_frags_z * 16);
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
8 * (tid / 16) + tid % 8, (tid % 16) / 8);
@@ -204,8 +208,8 @@ __global__ void multi_query_append_attention_kernel(
const uint32_t const_offset = kv_head_idx * kv_h_stride +
(wid * 4 + tid / 8) * kv_b_stride +
tid % 8 * num_elems_per_128b<T>();
T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset;
T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset;
const T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset;
const T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset;
produce_kv_blockwise<SharedMemFillMode::kNoFill,
NUM_WARPS,
@@ -264,7 +268,6 @@ __global__ void multi_query_append_attention_kernel(
s_frag,
mask_offset_this_seq,
sliding_window);
}
// update m,d
@@ -321,18 +324,22 @@ __global__ void multi_query_append_attention_kernel(
wait_group<0>();
__syncthreads();
if constexpr (!partition_kv ) {
if constexpr (!partition_kv) {
if (sinks) {
float current_sinks[num_frags_x][2];
#pragma unroll
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE;
current_sinks[fx][j] = static_cast<float>(sinks[q_head_idx + h_offset]);
const uint32_t h_offset =
(q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) %
GROUP_SIZE;
current_sinks[fx][j] =
static_cast<float>(sinks[q_head_idx + h_offset]);
}
}
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag, m_frag, current_sinks);
normalize_d<num_frags_x, num_frags_y>(
o_frag, d_frag, m_frag, current_sinks);
} else {
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
}
@@ -375,7 +382,6 @@ __global__ void multi_query_append_attention_kernel(
HEAD_DIM);
}
if constexpr (partition_kv) {
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
@@ -421,13 +427,13 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true>
__global__ void multi_query_append_attention_warp1_4_kernel(
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
T *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
T *__restrict__ cache_v,
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T *__restrict__ sinks, // [q_num_heads]
const T *__restrict__ sinks, // [q_num_heads]
const int *__restrict__ seq_lens,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
@@ -435,9 +441,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
const float scale,
const float quant_max_bound,
@@ -469,8 +474,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
// When cudagraph capture prefill, may launch more gridDim.x
if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
return;
}
@@ -478,8 +483,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
if (q_len <= 0) {
return;
}
const uint32_t q_end =
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));
uint32_t kv_len = seq_lens_kv[batch_id];
if (ENABLE_PREFILL) {
kv_len += q_len;
@@ -540,11 +544,16 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
const int *mask_offset_this_seq =
mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
tid % 16, tid / 16); // 16 * 16
const uint32_t q_end =
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));
load_q_global_smem_multi_warps<GROUP_SIZE,
num_frags_x,
num_frags_y,
@@ -576,11 +585,10 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
: chunk_len,
NUM_WARP_KV * num_frags_z * 16);
const uint32_t mask_check_iteration =
(CAUSAL ? (min(chunk_len,
sub_if_greater_or_zero(
kv_len - q_len,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
(CAUSAL ? (min(chunk_len,
sub_if_greater_or_zero(kv_len - q_len, chunk_start)))
: mask_offset ? 0
: chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);
uint32_t k_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -648,16 +656,18 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq,
sliding_window);
num_frags_z>(
attn_mask ? attn_mask + batch_id * attn_mask_len * attn_mask_len
: nullptr,
q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq,
sliding_window);
}
// update m,d
@@ -720,15 +730,19 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
if (num_chunks_this_seq <= 1) {
if (sinks) {
float current_sinks[num_frags_x][2];
#pragma unroll
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE;
current_sinks[fx][j] = static_cast<float>(sinks[q_head_idx + h_offset]);
const uint32_t h_offset =
(q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) %
GROUP_SIZE;
current_sinks[fx][j] =
static_cast<float>(sinks[q_head_idx + h_offset]);
}
}
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag, m_frag, current_sinks);
normalize_d<num_frags_x, num_frags_y>(
o_frag, d_frag, m_frag, current_sinks);
} else {
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
}
@@ -876,7 +890,6 @@ void MultiQueryAppendAttention(
CAUSAL,
num_warps,
NUM_WARP_Q,
NUM_WARP_KV,
HEAD_DIM,
BLOCK_SIZE,
num_frags_x,
@@ -908,7 +921,6 @@ void MultiQueryAppendAttention(
CAUSAL,
num_warps,
NUM_WARP_Q,
NUM_WARP_KV,
HEAD_DIM,
BLOCK_SIZE,
num_frags_x,
@@ -933,8 +945,8 @@ void MultiQueryAppendAttention(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -943,7 +955,6 @@ void MultiQueryAppendAttention(
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
scale,
quant_max_bound,
@@ -996,8 +1007,8 @@ void MultiQueryAppendAttention(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1006,7 +1017,6 @@ void MultiQueryAppendAttention(
block_table.data<int>(),
meta_data.mask_offset,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
scale,
quant_max_bound,
@@ -1048,8 +1058,8 @@ void MultiQueryAppendAttention(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
@@ -1087,8 +1097,8 @@ void MultiQueryAppendAttention(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
@@ -1138,9 +1148,9 @@ void MultiQueryAppendAttention(
uint32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
attn_mask_len = attn_mask.get().shape()[1];
} else {
attn_mask_len = -1;
attn_mask_len = -1;
}
const int num_chunks = div_up(max_seq_len, chunk_size);
@@ -1179,8 +1189,8 @@ void MultiQueryAppendAttention(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1189,9 +1199,8 @@ void MultiQueryAppendAttention(
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
scale,
quant_max_bound,
@@ -1250,14 +1259,14 @@ void MultiQueryAppendAttention(
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1266,9 +1275,8 @@ void MultiQueryAppendAttention(
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
scale,
quant_max_bound,
@@ -1306,14 +1314,14 @@ void MultiQueryAppendAttention(
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
@@ -1326,15 +1334,14 @@ void MultiQueryAppendAttention(
} else {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads);
dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
@@ -1345,14 +1352,14 @@ void MultiQueryAppendAttention(
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,