[Feature][OP] Append Attn Support CUDA-PDL (#5072)

This commit is contained in:
chen
2025-11-17 20:47:33 +08:00
committed by GitHub
parent c2c1942db9
commit d58c1db8a0
12 changed files with 2828 additions and 2068 deletions
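
This commit threads CUDA Programmatic Dependent Launch (PDL) through the append-attention decode path. On the device side, each kernel gains a cudaGridDependencySynchronize() call (guarded for __CUDA_ARCH__ >= 900, i.e. Hopper and newer) before it first reads data produced by the preceding grid, and a cudaTriggerProgrammaticLaunchCompletion() call once its dependent work is done, so the next grid on the stream can start launching while this one finishes. On the host side, raw <<<...>>> launches are routed through a launchWithPdlWhenEnabled(...) helper. Below is a minimal sketch of the device-side pattern the hunks apply; the toy kernel and buffer names are illustrative, not taken from the PR:

#include <cuda_runtime.h>

// Toy consumer kernel showing the same guard/sync/trigger placement the diff
// adds to the real kernels.
__global__ void pdl_consumer_kernel(const float* __restrict__ upstream_out,
                                    float* __restrict__ out,
                                    int n) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // Block until the preceding grid's global-memory writes are visible.
  cudaGridDependencySynchronize();
#endif
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = upstream_out[i] * 2.0f;
  }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // Tell the runtime the next grid on the stream may begin launching.
  cudaTriggerProgrammaticLaunchCompletion();
#endif
}

When a kernel is launched without the PDL attribute, the dependency is already satisfied at launch time, so these calls add no meaningful work and the non-PDL path behaves as before.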

View File

@@ -2296,6 +2296,9 @@ __global__ void merge_multi_chunks_decoder_kernel(
const int bid = blockIdx.x, hid = blockIdx.y;
__shared__ T smem[bdy * HEAD_DIM];
__shared__ float md_smem[bdy * 2];
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
const int start_token_idx = cu_seqlens_q[bid];
const int seq_len_q = seq_lens_q[bid];
if (seq_len_q == 0) return;
@@ -2332,6 +2335,10 @@ __global__ void merge_multi_chunks_decoder_kernel(
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
m = -3.0e+30f;
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
#pragma unroll 2
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
uint32_t offset = (bid * num_chunks + i) * num_heads + hid;
@@ -2397,6 +2404,9 @@ __global__ void merge_multi_chunks_decoder_kernel(
out_vec,
&out[(start_token_idx * num_heads + hid) * head_dim + vid * vec_size]);
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename T,
@@ -2433,6 +2443,9 @@ __global__ void merge_multi_chunks_v2_kernel(
const int hid = blockIdx.y;
__shared__ T smem[bdy * HEAD_DIM];
__shared__ float md_smem[bdy * 2];
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
const uint32_t bid = batch_id_per_token[qid];
if (bid == -1) {
@@ -2569,4 +2582,7 @@ __global__ void merge_multi_chunks_v2_kernel(
}
__syncthreads();
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

View File

@@ -109,6 +109,9 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
const int half_head_size = head_size / 2;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_dim;
gloabl_hi += all_warp_num) {
int64_t linear_index = gloabl_hi * head_size + threadIdx.x * VecSize;
@@ -198,6 +201,9 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
}
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename T, int VecSize = 1>
@@ -239,6 +245,9 @@ __global__ void append_decode_cache_T_rope_kernel(
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
// const int64_t offset = 2 * hidden_size;
const int half_head_size = head_size / 2;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
@@ -305,10 +314,13 @@ __global__ void append_decode_cache_T_rope_kernel(
}
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_rope_kernel(
__global__ void append_decode_cache_T_quant_rope_kernel(
const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
@@ -352,6 +364,9 @@ __global__ void append_decode_cache_T_rope_kernel(
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
// const int64_t offset = 2 * hidden_size;
const int half_head_size = head_size / 2;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
@@ -427,6 +442,9 @@ __global__ void append_decode_cache_T_rope_kernel(
}
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename T, int VecSize = 1>
@@ -473,7 +491,9 @@ __global__ void append_decode_cache_T_neox_partial_rope_kernel(
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
const int64_t half_hidden_size = hidden_size / 2;
// const int64_t offset = 2 * hidden_size;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
@@ -566,6 +586,9 @@ __global__ void append_decode_cache_T_neox_partial_rope_kernel(
}
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename T, int VecSize = 1>
@@ -608,7 +631,9 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
const int64_t half_hidden_size = hidden_size / 2;
// const int64_t offset = 2 * hidden_size;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
@@ -680,10 +705,13 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
}
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_neox_rope_kernel(
__global__ void append_decode_cache_T_quant_neox_rope_kernel(
const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
@@ -726,7 +754,9 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
const int half_head_size = head_size / 2;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
const int64_t half_hidden_size = hidden_size / 2;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
@@ -814,6 +844,9 @@ __global__ void append_decode_cache_T_neox_rope_kernel(
}
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename T,
@@ -872,7 +905,9 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
float thread_m2 = 0.0f;
float warp_m2 = 0.0f;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
if (head_idx < num_heads) {
// q
using LoadT = AlignedVector<T, VecSize>;
@@ -1118,6 +1153,9 @@ __global__ void append_decode_cache_int8_rope_qk_norm_kernel(
value_cache[tgt_cache_idx4] = cache_vec[3];
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename T,
@@ -1169,7 +1207,9 @@ __global__ void append_decode_cache_int8_rope_kernel(
block_table_now = block_tables + bid * max_blocks_per_seq;
const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
const int block_offset = write_seq_id % block_size;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
if (head_idx < num_heads) {
// q
const T* qkv_now =
@@ -1356,6 +1396,9 @@ __global__ void append_decode_cache_int8_rope_kernel(
value_cache[tgt_cache_idx4] = cache_vec[3];
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename T,
@@ -1364,7 +1407,7 @@ template <typename T,
int HeadDim = 128,
bool is_scale_channel_wise = false,
bool IsFP8 = false>
__global__ void append_decode_cache_int8_rope_kernel(
__global__ void int_append_decode_cache_int8_rope_kernel(
const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads,
@@ -1412,7 +1455,9 @@ __global__ void append_decode_cache_int8_rope_kernel(
block_table_now = block_tables + bid * max_blocks_per_seq;
const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
const int block_offset = write_seq_id % block_size;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
if (head_idx < num_heads) {
// q
using LoadT = AlignedVector<int, VecSize>;
@@ -1674,6 +1719,9 @@ __global__ void append_decode_cache_int8_rope_kernel(
value_cache[tgt_cache_idx4] = cache_vec[3];
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128>
@@ -1721,7 +1769,9 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
block_table_now = block_tables + bid * max_blocks_per_seq;
const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
const int block_offset = write_seq_id % block_size;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
if (head_idx < num_heads) {
// q
using LoadT = AlignedVector<T, VecSize>;
@@ -1977,10 +2027,13 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
value_cache[tgt_cache_idx4] = cache_vec[3];
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128>
__global__ void append_decode_cache_int8_neox_rope_kernel(
__global__ void int_append_decode_cache_int8_neox_rope_kernel(
const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads,
@@ -2030,7 +2083,9 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
block_table_now = block_tables + bid * max_blocks_per_seq;
const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
const int block_offset = write_seq_id % block_size;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
if (head_idx < num_heads) {
// q
using LoadT = AlignedVector<int, VecSize>;
@@ -2374,6 +2429,9 @@ __global__ void append_decode_cache_int8_neox_rope_kernel(
value_cache[tgt_cache_idx4] = cache_vec[3];
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128>
@@ -2424,7 +2482,9 @@ __global__ void append_decode_cache_int4_rope_kernel(
const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
const int block_offset = write_seq_id % block_size;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
if (head_idx < num_heads) {
// q
const T* qkv_now =
@@ -2648,10 +2708,13 @@ __global__ void append_decode_cache_int4_rope_kernel(
(uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F);
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128>
__global__ void append_decode_cache_int4_rope_kernel(
__global__ void int_append_decode_cache_int4_rope_kernel(
const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads,
@@ -2703,7 +2766,9 @@ __global__ void append_decode_cache_int4_rope_kernel(
const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
const int block_offset = write_seq_id % block_size;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
if (head_idx < num_heads) {
// q
using LoadT = AlignedVector<int, VecSize>;
@@ -2981,6 +3046,9 @@ __global__ void append_decode_cache_int4_rope_kernel(
(uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F);
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128>
@@ -3031,7 +3099,9 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
const int block_offset = write_seq_id % block_size;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
if (head_idx < num_heads) {
// q
using LoadT = AlignedVector<T, VecSize>;
@@ -3355,10 +3425,13 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
(uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F);
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename T, int VecSize = 4, int RoundType = 0, int HeadDim = 128>
__global__ void append_decode_cache_int4_neox_rope_kernel(
__global__ void int_append_decode_cache_int4_neox_rope_kernel(
const int* __restrict__ quant_qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
uint8_t* __restrict__ key_cache, // [num_blocks, kv_num_heads,
@@ -3410,7 +3483,9 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
const int block_idx = __ldg(&block_table_now[write_seq_id / block_size]);
const int block_offset = write_seq_id % block_size;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
if (head_idx < num_heads) {
// q
using LoadT = AlignedVector<int, VecSize>;
@@ -3808,4 +3883,7 @@ __global__ void append_decode_cache_int4_neox_rope_kernel(
(uint_quant_value2 << 4) | (uint_quant_value1 & 0x0F);
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
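
Note the renames in the file above: the overloads that take a quantized const int* qkv input get distinct names (append_decode_cache_T_quant_rope_kernel, append_decode_cache_T_quant_neox_rope_kernel, int_append_decode_cache_int8_rope_kernel, and so on). This is presumably needed because the new launch helper receives the kernel as a template argument, and an overload set cannot be passed that way, whereas a direct <<<...>>> call resolves the overload from its arguments. A hypothetical reduction of the problem follows; the launchWithPdl stand-in is not the PR's real helper:

#include <cuda_runtime.h>

// Two kernels that differ only in the qkv input type, as in the file above.
__global__ void cache_kernel(const float* qkv, float* out) { out[0] = qkv[0]; }
__global__ void cache_kernel(const int* qkv, float* out) {
  out[0] = static_cast<float>(qkv[0]);
}

// Simplified stand-in for a wrapper that takes the kernel as an argument.
template <typename Kernel, typename... Args>
void launchWithPdl(Kernel kernel, dim3 grid, dim3 block, Args... args) {
  kernel<<<grid, block>>>(args...);
}

void demo(const int* quant_qkv, float* out) {
  cache_kernel<<<1, 1>>>(quant_qkv, out);  // OK: overload picked from arguments
  // launchWithPdl(cache_kernel, 1, 1, quant_qkv, out);
  //   ^ fails to compile: 'cache_kernel' names an overload set, so the
  //     template parameter Kernel cannot be deduced -- hence the renames.
}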

View File

@@ -52,28 +52,33 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
int grid_size = 1;
GetNumBlocks<128>(pack_num, &grid_size);
dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
append_decode_cache_T_rope_qk_norm_kernel<T, PackSize>
<<<grid_size, block_dim, 0, stream>>>(reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d,
q_norm_weight,
k_norm_weight,
rms_norm_eps);
launchWithPdlWhenEnabled(
append_decode_cache_T_rope_qk_norm_kernel<T, PackSize>,
grid_size,
block_dim,
0,
stream,
reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d,
q_norm_weight,
k_norm_weight,
rms_norm_eps);
}
template <typename T, typename QKV_TYPE>
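
Every launcher in this file (and the attention launchers later in the diff) now goes through launchWithPdlWhenEnabled(kernel, grid, block, smem_size, stream, args...) instead of kernel<<<grid, block, smem_size, stream>>>(args...). The helper's definition is not part of the hunks shown here. A plausible sketch, assuming it forwards to cudaLaunchKernelEx with the programmatic-stream-serialization attribute when PDL is enabled and falls back to a plain launch otherwise (the environment-variable switch and the exact fallback behaviour are assumptions, not taken from the PR):

#include <cuda_runtime.h>
#include <cstdlib>

// Sketch only: what a launchWithPdlWhenEnabled-style wrapper could look like.
template <typename Kernel, typename... Args>
void launchWithPdlSketch(Kernel kernel, dim3 grid, dim3 block, size_t smem,
                         cudaStream_t stream, Args... args) {
  static const bool pdl_enabled = [] {
    const char* flag = std::getenv("ENABLE_PDL");  // hypothetical switch
    return flag != nullptr && flag[0] == '1';
  }();
  if (pdl_enabled) {
    cudaLaunchConfig_t config = {};
    config.gridDim = grid;
    config.blockDim = block;
    config.dynamicSmemBytes = smem;
    config.stream = stream;
    cudaLaunchAttribute attr[1];
    attr[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attr[0].val.programmaticStreamSerializationAllowed = 1;
    config.attrs = attr;
    config.numAttrs = 1;
    // With this attribute set, the grid may begin launching as soon as the
    // preceding grid on the stream calls cudaTriggerProgrammaticLaunchCompletion().
    cudaLaunchKernelEx(&config, kernel, args...);
  } else {
    kernel<<<grid, block, smem, stream>>>(args...);
  }
}

Correctness of the early launch still rests on the device-side cudaGridDependencySynchronize() calls added above, which hold the dependent grid until the producer's global writes are visible.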
@@ -111,118 +116,140 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
GetNumBlocks<128>(pack_num, &grid_size);
if (use_neox_style) {
if (qkv_out_scales) {
append_decode_cache_T_neox_rope_kernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<const int*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
qkv_out_scales,
qkv_biases,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
launchWithPdlWhenEnabled(
append_decode_cache_T_quant_neox_rope_kernel<T, PackSize>,
grid_size,
blocksize,
0,
stream,
reinterpret_cast<const int*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
qkv_out_scales,
qkv_biases,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
} else {
if (rotary_dim < dim_head) {
append_decode_cache_T_neox_partial_rope_kernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
rotary_dim,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
auto* kernelFn =
append_decode_cache_T_neox_partial_rope_kernel<T, PackSize>;
launchWithPdlWhenEnabled(kernelFn,
grid_size,
blocksize,
0,
stream,
reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
rotary_dim,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
} else {
append_decode_cache_T_neox_rope_kernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
auto* kernelFn = append_decode_cache_T_neox_rope_kernel<T, PackSize>;
launchWithPdlWhenEnabled(kernelFn,
grid_size,
blocksize,
0,
stream,
reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
}
}
} else {
if (qkv_out_scales) {
append_decode_cache_T_rope_kernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<const int*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
qkv_out_scales,
qkv_biases,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
launchWithPdlWhenEnabled(
append_decode_cache_T_quant_rope_kernel<T, PackSize>,
grid_size,
blocksize,
0,
stream,
reinterpret_cast<const int*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
qkv_out_scales,
qkv_biases,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
} else {
append_decode_cache_T_rope_kernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
auto* kernelFn = append_decode_cache_T_rope_kernel<T, PackSize>;
launchWithPdlWhenEnabled(kernelFn,
grid_size,
blocksize,
0,
stream,
reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
}
}
}
@@ -261,113 +288,128 @@ void append_decode_cache_int8_rope(const QKV_TYPE* qkv,
dim3 grids(bsz, all_warps / num_warps);
if (use_neox_style) {
if (qkv_out_scales) {
append_decode_cache_int8_neox_rope_kernel<T, 4>
<<<grids, num_warps * 32, 0, stream>>>(
reinterpret_cast<const int*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
qkv_out_scales,
qkv_biases,
cache_k_scale,
cache_v_scale,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
launchWithPdlWhenEnabled(
int_append_decode_cache_int8_neox_rope_kernel<T, 4>,
grids,
num_warps * 32,
0,
stream,
reinterpret_cast<const int*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
qkv_out_scales,
qkv_biases,
cache_k_scale,
cache_v_scale,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
} else {
append_decode_cache_int8_neox_rope_kernel<T, 4>
<<<grids, num_warps * 32, 0, stream>>>(
reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
cache_k_scale,
cache_v_scale,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
launchWithPdlWhenEnabled(append_decode_cache_int8_neox_rope_kernel<T, 4>,
grids,
num_warps * 32,
0,
stream,
reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
cache_k_scale,
cache_v_scale,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
}
} else {
if (qkv_out_scales) {
append_decode_cache_int8_rope_kernel<T,
4,
0,
128,
is_scale_channel_wise,
IsFP8>
<<<grids, num_warps * 32, 0, stream>>>(
reinterpret_cast<const int*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
qkv_out_scales,
qkv_biases,
cache_k_scale,
cache_v_scale,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
launchWithPdlWhenEnabled(
int_append_decode_cache_int8_rope_kernel<T,
4,
0,
128,
is_scale_channel_wise,
IsFP8>,
grids,
num_warps * 32,
0,
stream,
reinterpret_cast<const int*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
qkv_out_scales,
qkv_biases,
cache_k_scale,
cache_v_scale,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
} else {
append_decode_cache_int8_rope_kernel<T,
4,
0,
128,
is_scale_channel_wise,
IsFP8>
<<<grids, num_warps * 32, 0, stream>>>(
reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
cache_k_scale,
cache_v_scale,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
launchWithPdlWhenEnabled(
append_decode_cache_int8_rope_kernel<T,
4,
0,
128,
is_scale_channel_wise,
IsFP8>,
grids,
num_warps * 32,
0,
stream,
reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
cache_k_scale,
cache_v_scale,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d);
}
}
}
@@ -405,111 +447,124 @@ void append_decode_cache_int4_rope(const QKV_TYPE* qkv,
dim3 grids(bsz, all_warps / num_warps);
if (use_neox_style) {
if (qkv_out_scales) {
append_decode_cache_int4_neox_rope_kernel<T, 4>
<<<grids, num_warps * 32, 0, stream>>>(
reinterpret_cast<const int*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
qkv_out_scales,
qkv_biases,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
launchWithPdlWhenEnabled(
int_append_decode_cache_int4_neox_rope_kernel<T, 4>,
grids,
num_warps * 32,
0,
stream,
reinterpret_cast<const int*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
qkv_out_scales,
qkv_biases,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
} else {
append_decode_cache_int4_neox_rope_kernel<T, 4>
<<<grids, num_warps * 32, 0, stream>>>(
reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
launchWithPdlWhenEnabled(append_decode_cache_int4_neox_rope_kernel<T, 4>,
grids,
num_warps * 32,
0,
stream,
reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
}
} else {
if (qkv_out_scales) {
append_decode_cache_int4_rope_kernel<T, 4>
<<<grids, num_warps * 32, 0, stream>>>(
reinterpret_cast<const int*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
qkv_out_scales,
qkv_biases,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
launchWithPdlWhenEnabled(int_append_decode_cache_int4_rope_kernel<T, 4>,
grids,
num_warps * 32,
0,
stream,
reinterpret_cast<const int*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
qkv_out_scales,
qkv_biases,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
} else {
append_decode_cache_int4_rope_kernel<T, 4>
<<<grids, num_warps * 32, 0, stream>>>(
reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
launchWithPdlWhenEnabled(append_decode_cache_int4_rope_kernel<T, 4>,
grids,
num_warps * 32,
0,
stream,
reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
cache_k_scale,
cache_v_scale,
cache_k_zp,
cache_v_zp,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
7.0f,
-8.0f,
kv_num_heads,
rope_3d);
}
}
}
@@ -610,77 +665,85 @@ void DecoderWriteCacheWithRoPEKernel(
const int all_warps = ((num_heads + 2 * kv_num_heads) + num_warps - 1) /
num_warps * num_warps;
dim3 grids(bsz, all_warps / num_warps);
append_decode_cache_int8_rope_qk_norm_kernel<DataType_,
4,
0,
128,
false,
true,
true>
<<<grids, num_warps * 32, 0, stream>>>(
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
cache_k_scale.get().data<T>())),
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
(cache_v_scale.get().data<T>()))),
q_norm_weight.get().data<float>(),
k_norm_weight.get().data<float>(),
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d,
rms_norm_eps);
launchWithPdlWhenEnabled(
append_decode_cache_int8_rope_qk_norm_kernel<DataType_,
4,
0,
128,
false,
true,
true>,
grids,
num_warps * 32,
0,
stream,
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
cache_k_scale.get().data<T>())),
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
(cache_v_scale.get().data<T>()))),
q_norm_weight.get().data<float>(),
k_norm_weight.get().data<float>(),
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d,
rms_norm_eps);
} else if ((cache_quant_type_str == "cache_fp8")) {
constexpr int num_warps = 4;
const int all_warps = ((num_heads + 2 * kv_num_heads) + num_warps - 1) /
num_warps * num_warps;
dim3 grids(bsz, all_warps / num_warps);
append_decode_cache_int8_rope_qk_norm_kernel<DataType_,
4,
0,
128,
false,
true,
false>
<<<grids, num_warps * 32, 0, stream>>>(
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
cache_k_scale.get().data<T>())),
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
(cache_v_scale.get().data<T>()))),
q_norm_weight.get().data<float>(),
k_norm_weight.get().data<float>(),
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d,
rms_norm_eps);
launchWithPdlWhenEnabled(
append_decode_cache_int8_rope_qk_norm_kernel<DataType_,
4,
0,
128,
false,
true,
false>,
grids,
num_warps * 32,
0,
stream,
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
cache_k_scale.get().data<T>())),
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
(cache_v_scale.get().data<T>()))),
q_norm_weight.get().data<float>(),
k_norm_weight.get().data<float>(),
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d,
rms_norm_eps);
} else {
PD_THROW(
"append_decode_cache_rope_qk_norm just supports cache_quant_type "
@@ -822,38 +885,42 @@ void DecoderWriteCacheWithRoPEKernel(
const int all_warps = ((num_heads + 2 * kv_num_heads) + num_warps - 1) /
num_warps * num_warps;
dim3 grids(bsz, all_warps / num_warps);
append_decode_cache_int8_rope_qk_norm_kernel<DataType_,
4,
0,
128,
false,
true>
<<<grids, num_warps * 32, 0, stream>>>(
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
cache_k_scale.get().data<T>())),
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
(cache_v_scale.get().data<T>()))),
nullptr,
nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d,
rms_norm_eps);
launchWithPdlWhenEnabled(
append_decode_cache_int8_rope_qk_norm_kernel<DataType_,
4,
0,
128,
false,
true>,
grids,
num_warps * 32,
0,
stream,
reinterpret_cast<const DataType_*>(qkv_ptr),
key_cache_out->data<uint8_t>(),
value_cache_out->data<uint8_t>(),
reinterpret_cast<DataType_*>(qkv_out->data<T>()),
block_tables.data<int>(),
cu_seqlens_q.data<int>(),
seq_lens.data<int>(),
seq_lens_encoder.data<int>(),
cos_emb,
sin_emb,
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
cache_k_scale.get().data<T>())),
const_cast<DataType_*>(reinterpret_cast<const DataType_*>(
(cache_v_scale.get().data<T>()))),
nullptr,
nullptr,
max_seq_len,
max_blocks_per_seq,
num_heads,
block_size,
127.0f,
-127.0f,
kv_num_heads,
rope_3d,
rms_norm_eps);
} else if (cache_quant_type_str == "cache_int4_zp") {
append_decode_cache_int4_rope(
reinterpret_cast<const QKV_TYPE*>(qkv_ptr),

File diff suppressed because it is too large

View File

@@ -157,7 +157,9 @@ __global__ void multi_query_append_attention_kernel(
const uint32_t q_end =
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
load_q_global_smem<GROUP_SIZE, num_frags_x, num_frags_y, HEAD_DIM, T>(
q_base_ptr,
&qo_smem,
@@ -410,6 +412,9 @@ __global__ void multi_query_append_attention_kernel(
}
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename T,
@@ -554,6 +559,10 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const uint32_t q_end =
min(q_len, div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE));
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
load_q_global_smem_multi_warps<GROUP_SIZE,
num_frags_x,
num_frags_y,
@@ -819,6 +828,9 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
}
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename T,
@@ -933,8 +945,12 @@ void MultiQueryAppendAttention(
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
}
nosplit_kv_kernel<<<grids, blocks, smem_size, stream>>>(
launchWithPdlWhenEnabled(
nosplit_kv_kernel,
grids,
blocks,
smem_size,
stream,
reinterpret_cast<NV_TYPE *>(const_cast<T *>(qkv.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
@@ -996,7 +1012,12 @@ void MultiQueryAppendAttention(
num_chunks * num_heads));
}
split_kv_kernel<<<grids, blocks, smem_size, stream>>>(
launchWithPdlWhenEnabled(
split_kv_kernel,
grids,
blocks,
smem_size,
stream,
reinterpret_cast<NV_TYPE *>(const_cast<T *>(qkv.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
@@ -1037,79 +1058,89 @@ void MultiQueryAppendAttention(
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(bsz, num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_decoder_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM);
auto *kernelFn = merge_multi_chunks_decoder_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>;
launchWithPdlWhenEnabled(
kernelFn,
grids_merge,
blocks_merge,
0,
stream,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM);
} else {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads); // 128k is too large
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM,
token_num,
speculate_max_draft_token_num);
auto *kernelFn = merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>;
launchWithPdlWhenEnabled(
kernelFn,
grids_merge,
blocks_merge,
0,
stream,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM,
token_num,
speculate_max_draft_token_num);
}
}
} else {
@@ -1177,8 +1208,12 @@ void MultiQueryAppendAttention(
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
}
nosplit_kv_kernel<<<grids, blocks, smem_size, stream>>>(
launchWithPdlWhenEnabled(
nosplit_kv_kernel,
grids,
blocks,
smem_size,
stream,
reinterpret_cast<NV_TYPE *>(const_cast<T *>(qkv.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
@@ -1254,7 +1289,12 @@ void MultiQueryAppendAttention(
num_chunks * num_heads));
}
}
split_kv_kernel<<<grids, blocks, smem_size, stream>>>(
launchWithPdlWhenEnabled(
split_kv_kernel,
grids,
blocks,
smem_size,
stream,
reinterpret_cast<NV_TYPE *>(const_cast<T *>(qkv.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k.data<T>())),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v.data<T>())),
@@ -1299,78 +1339,88 @@ void MultiQueryAppendAttention(
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(bsz, num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_decoder_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM);
auto *kernelFn = merge_multi_chunks_decoder_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>;
launchWithPdlWhenEnabled(
kernelFn,
grids_merge,
blocks_merge,
0,
stream,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM);
} else {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM,
token_num,
speculate_max_draft_token_num);
auto *kernelFn = merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>;
launchWithPdlWhenEnabled(
kernelFn,
grids_merge,
blocks_merge,
0,
stream,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM,
token_num,
speculate_max_draft_token_num);
}
}
}

View File

@@ -31,7 +31,7 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true>
__global__ void multi_query_append_attention_c4_kernel(
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
CacheT *__restrict__ cache_v,
@@ -41,7 +41,7 @@ __global__ void multi_query_append_attention_c4_kernel(
const T *__restrict__ cache_v_zero_point, // [num_kv_heads, head_dim]
const T *__restrict__ shift_bias, // [q_num_heads * HEAD_DIM]
const T *__restrict__ smooth_weight, // [q_num_heads * HEAD_DIM]
const T *__restrict__ sinks, // [q_num_heads]
const T *__restrict__ sinks, // [q_num_heads]
const int *__restrict__ seq_lens,
const int *__restrict__ seq_lens_kv,
const int *__restrict__ batch_ids,
@@ -87,8 +87,8 @@ __global__ void multi_query_append_attention_c4_kernel(
block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
// When cudagraph capture prefill, may launch more gridDim.x
if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
return;
}
@@ -125,6 +125,9 @@ __global__ void multi_query_append_attention_c4_kernel(
float o_frag[num_frags_x][num_frags_y][8];
float m_frag[num_frags_x][2];
float d_frag[num_frags_x][2];
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
const T *cache_k_scale_now = cache_k_scale + kv_head_idx * HEAD_DIM;
const T *cache_k_zp_now = cache_k_zero_point + kv_head_idx * HEAD_DIM;
@@ -180,7 +183,8 @@ __global__ void multi_query_append_attention_c4_kernel(
} else {
o_base_ptr_int8 = out + o_offset;
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
const int *mask_offset_this_seq =
mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
@@ -241,7 +245,6 @@ __global__ void multi_query_append_attention_c4_kernel(
v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) +
num_frags_z * 16 * HEAD_DIM / 2 * sizeof(CacheT));
const uint32_t num_iterations = div_up(
CAUSAL
? (min(chunk_len,
@@ -252,12 +255,13 @@ __global__ void multi_query_append_attention_c4_kernel(
: chunk_len,
num_frags_z * 16);
const uint32_t mask_check_iteration =
(CAUSAL ? (min(chunk_len,
(CAUSAL ? (min(chunk_len,
sub_if_greater_or_zero(
kv_len - q_len +
tile_id * num_rows_per_block / GROUP_SIZE,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
: mask_offset ? 0
: chunk_len) /
(num_frags_z * 16);
uint32_t k_smem_offset_r =
@@ -270,9 +274,7 @@ __global__ void multi_query_append_attention_c4_kernel(
uint32_t k_smem_offset_w =
smem_t::get_permuted_offset<num_vecs_per_head_k, inv_k_stride>(
wid * 8 + tid / 4,
tid %
4);
wid * 8 + tid / 4, tid % 4);
uint32_t v_smem_offset_w =
smem_t::get_permuted_offset<num_vecs_per_blocksize, inv_v_stride>(
wid * 16 + tid / 2, tid % 2); // 2 * 128 / 8 = 32B, 64 nums
@@ -417,15 +419,19 @@ __global__ void multi_query_append_attention_c4_kernel(
if constexpr (!partition_kv) {
if (sinks) {
float current_sinks[num_frags_x][2];
#pragma unroll
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE;
current_sinks[fx][j] = static_cast<float>(sinks[q_head_idx + h_offset]);
const uint32_t h_offset =
(q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) %
GROUP_SIZE;
current_sinks[fx][j] =
static_cast<float>(sinks[q_head_idx + h_offset]);
}
}
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag, m_frag, current_sinks);
normalize_d<num_frags_x, num_frags_y>(
o_frag, d_frag, m_frag, current_sinks);
} else {
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
}
@@ -497,6 +503,9 @@ __global__ void multi_query_append_attention_c4_kernel(
}
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename T,
@@ -515,7 +524,7 @@ template <typename T,
typename OutT = T,
bool ENABLE_PREFILL = true>
__global__ void multi_query_append_attention_c4_warp1_4_kernel(
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
CacheT *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
// head_dim]
CacheT *__restrict__ cache_v,
@@ -533,7 +542,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const int *__restrict__ cu_seqlens_q,
const int *__restrict__ block_table, // [bsz, block_num_per_seq]
const int *__restrict__ mask_offset,
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask
const int max_seq_len,
const int max_dec_len,
const int max_block_num_per_seq,
@@ -573,8 +582,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
const uint32_t num_rows_per_block = num_frags_x * 16;
const int *block_table_now = block_table + batch_id * max_block_num_per_seq;
//When cudagraph capture prefill, may launch more gridDim.x
if(btid >= static_cast<uint32_t>(num_blocks_x_cpu)){
// When cudagraph capture prefill, may launch more gridDim.x
if (btid >= static_cast<uint32_t>(num_blocks_x_cpu)) {
return;
}
@@ -612,6 +621,9 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
float m_frag[num_frags_x][2];
float d_frag[num_frags_x][2];
init_states<T, num_frags_x, num_frags_y>(o_frag, m_frag, d_frag);
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
const T *cache_k_scale_now = cache_k_scale + kv_head_idx * HEAD_DIM;
const T *cache_k_zp_now = cache_k_zero_point + kv_head_idx * HEAD_DIM;
@@ -664,11 +676,12 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
tid % 8 * num_elems_per_128b<T>();
}
}
const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
const int *mask_offset_this_seq =
mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
smem_t qo_smem(smem);
uint32_t q_smem_offset_r = smem_t::get_permuted_offset<num_vecs_per_head>(
tid % 16, tid / 16);
uint32_t q_smem_offset_r =
smem_t::get_permuted_offset<num_vecs_per_head>(tid % 16, tid / 16);
load_q_global_smem_multi_warps<GROUP_SIZE,
num_frags_x,
num_frags_y,
@@ -738,11 +751,10 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
: chunk_len,
NUM_WARP_KV * num_frags_z * 16);
const uint32_t mask_check_iteration =
(CAUSAL ? (min(chunk_len,
sub_if_greater_or_zero(
kv_len - q_len,
chunk_start)))
: mask_offset ? 0 : chunk_len) /
(CAUSAL ? (min(chunk_len,
sub_if_greater_or_zero(kv_len - q_len, chunk_start)))
: mask_offset ? 0
: chunk_len) /
(NUM_WARP_KV * num_frags_z * 16);
uint32_t k_smem_offset_r =
@@ -755,9 +767,7 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
uint32_t k_smem_offset_w =
smem_t::get_permuted_offset<num_vecs_per_head_k, inv_k_stride>(
wid * 8 + tid / 4,
tid %
4);
wid * 8 + tid / 4, tid % 4);
uint32_t v_smem_offset_w =
smem_t::get_permuted_offset<num_vecs_per_blocksize, inv_v_stride>(
wid * 16 + tid / 2, tid % 2);
@@ -824,16 +834,18 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
NUM_WARPS,
num_frags_x,
num_frags_y,
num_frags_z>(attn_mask ? attn_mask + batch_id * attn_mask_len *attn_mask_len : nullptr,
q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq,
sliding_window);
num_frags_z>(
attn_mask ? attn_mask + batch_id * attn_mask_len * attn_mask_len
: nullptr,
q_base_seq_id_this_block,
kv_idx_base + wid * num_frags_z * 16,
q_len,
kv_len,
chunk_end,
attn_mask_len,
s_frag,
mask_offset_this_seq,
sliding_window);
}
update_mdo_states<num_frags_x, num_frags_y, num_frags_z>(
@@ -903,15 +915,19 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
if (num_chunks_this_seq <= 1) {
if (sinks) {
float current_sinks[num_frags_x][2];
#pragma unroll
#pragma unroll
for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
#pragma unroll
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
const uint32_t h_offset = (q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) % GROUP_SIZE;
current_sinks[fx][j] = static_cast<float>(sinks[q_head_idx + h_offset]);
const uint32_t h_offset =
(q_base_seq_id_this_block + fx * 16 + tid / 4 + 8 * j) %
GROUP_SIZE;
current_sinks[fx][j] =
static_cast<float>(sinks[q_head_idx + h_offset]);
}
}
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag, m_frag, current_sinks);
normalize_d<num_frags_x, num_frags_y>(
o_frag, d_frag, m_frag, current_sinks);
} else {
normalize_d<num_frags_x, num_frags_y>(o_frag, d_frag);
}
@@ -987,6 +1003,9 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
}
}
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename T,
@@ -1119,7 +1138,12 @@ void MultiQueryAppendC4Attention(
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
}
nosplit_kv_kernel<<<grids, blocks, smem_size, stream>>>(
launchWithPdlWhenEnabled(
nosplit_kv_kernel,
grids,
blocks,
smem_size,
stream,
reinterpret_cast<NV_TYPE *>(const_cast<T *>(qkv.data<T>())),
const_cast<uint8_t *>(cache_k.data<uint8_t>()),
const_cast<uint8_t *>(cache_v.data<uint8_t>()),
@@ -1138,8 +1162,8 @@ void MultiQueryAppendC4Attention(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1188,7 +1212,12 @@ void MultiQueryAppendC4Attention(
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads));
}
split_kv_kernel<<<grids, blocks, smem_size, stream>>>(
launchWithPdlWhenEnabled(
split_kv_kernel,
grids,
blocks,
smem_size,
stream,
reinterpret_cast<NV_TYPE *>(const_cast<T *>(qkv.data<T>())),
const_cast<uint8_t *>(cache_k.data<uint8_t>()),
const_cast<uint8_t *>(cache_v.data<uint8_t>()),
@@ -1207,8 +1236,8 @@ void MultiQueryAppendC4Attention(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1238,79 +1267,86 @@ void MultiQueryAppendC4Attention(
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(bsz, num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_decoder_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM);
launchWithPdlWhenEnabled(
merge_multi_chunks_decoder_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>,
grids_merge,
blocks_merge,
0,
stream,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM);
} else {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads);
dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
dim3 blocks_merge(blockx, blocky);
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>
<<<grids_merge, blocks_merge, 0, stream>>>(
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(const_cast<T *>(
smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM,
token_num,
speculate_max_draft_token_num);
launchWithPdlWhenEnabled(
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>,
grids_merge,
blocks_merge,
0,
stream,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM,
token_num,
speculate_max_draft_token_num);
}
}
} else {
@@ -1353,7 +1389,6 @@ void MultiQueryAppendC4Attention(
const float ratio = static_cast<float>(num_blocks_need) /
static_cast<float>(num_blocks_per_wave);
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
if (!is_decoder) {
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
@@ -1362,9 +1397,9 @@ void MultiQueryAppendC4Attention(
const int num_chunks = div_up(max_seq_len, chunk_size);
uint32_t attn_mask_len;
if (attn_mask) {
attn_mask_len = attn_mask.get().shape()[1];
attn_mask_len = attn_mask.get().shape()[1];
} else {
attn_mask_len = -1;
attn_mask_len = -1;
}
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
@@ -1391,7 +1426,12 @@ void MultiQueryAppendC4Attention(
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
}
launchWithPdlWhenEnabled(
nosplit_kv_kernel,
grids,
blocks,
smem_size,
stream,
reinterpret_cast<NV_TYPE *>(const_cast<T *>(qkv.data<T>())),
const_cast<uint8_t *>(cache_k.data<uint8_t>()),
const_cast<uint8_t *>(cache_v.data<uint8_t>()),
@@ -1410,8 +1450,8 @@ void MultiQueryAppendC4Attention(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1420,7 +1460,7 @@ void MultiQueryAppendC4Attention(
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1476,27 +1516,32 @@ void MultiQueryAppendC4Attention(
num_chunks * num_heads));
}
}
launchWithPdlWhenEnabled(
split_kv_kernel,
grids,
blocks,
smem_size,
stream,
reinterpret_cast<NV_TYPE *>(const_cast<T *>(qkv.data<T>())),
const_cast<uint8_t *>(cache_k.data<uint8_t>()),
const_cast<uint8_t *>(cache_v.data<uint8_t>()),
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_k_scale.data<T>())),
cache_k_zp ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(cache_k_zp.get().data<T>()))
: nullptr,
reinterpret_cast<NV_TYPE *>(const_cast<T *>(cache_v_scale.data<T>())),
cache_v_zp ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(cache_v_zp.get().data<T>()))
: nullptr,
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
batch_ids.data<int>(),
@@ -1505,7 +1550,7 @@ void MultiQueryAppendC4Attention(
block_table.data<int>(),
meta_data.mask_offset,
attn_mask ? const_cast<bool *>(attn_mask.get().data<bool>())
: nullptr,
max_seq_len,
max_dec_len,
max_block_num_per_seq,
@@ -1529,79 +1574,86 @@ void MultiQueryAppendC4Attention(
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(bsz, num_heads);
dim3 blocks_merge(blockx, blocky);
launchWithPdlWhenEnabled(
merge_multi_chunks_decoder_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>,
grids_merge,
blocks_merge,
0,
stream,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM);
} else {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
dim3 blocks_merge(blockx, blocky);
launchWithPdlWhenEnabled(
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>,
grids_merge,
blocks_merge,
0,
stream,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM,
token_num,
speculate_max_draft_token_num);
}
}
}

View File

@@ -1271,8 +1271,12 @@ void MultiQueryAppendC8Attention(
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
}
launchWithPdlWhenEnabled(
nosplit_kv_kernel,
grids,
blocks,
smem_size,
stream,
reinterpret_cast<NV_TYPE *>(const_cast<T *>(qkv.data<T>())),
const_cast<uint8_t *>(cache_k.data<uint8_t>()),
const_cast<uint8_t *>(cache_v.data<uint8_t>()),
@@ -1335,7 +1339,12 @@ void MultiQueryAppendC8Attention(
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads));
}
launchWithPdlWhenEnabled(
split_kv_kernel,
grids,
blocks,
smem_size,
stream,
reinterpret_cast<NV_TYPE *>(const_cast<T *>(qkv.data<T>())),
const_cast<uint8_t *>(cache_k.data<uint8_t>()),
const_cast<uint8_t *>(cache_v.data<uint8_t>()),
@@ -1379,78 +1388,86 @@ void MultiQueryAppendC8Attention(
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(bsz, num_heads);
dim3 blocks_merge(blockx, blocky);
launchWithPdlWhenEnabled(
merge_multi_chunks_decoder_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>,
grids_merge,
blocks_merge,
0,
stream,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM);
} else {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
dim3 blocks_merge(blockx, blocky);
launchWithPdlWhenEnabled(
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>,
grids_merge,
blocks_merge,
0,
stream,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM,
token_num,
speculate_max_draft_token_num);
}
}
} else {
@@ -1568,8 +1585,12 @@ void MultiQueryAppendC8Attention(
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
}
launchWithPdlWhenEnabled(
nosplit_kv_kernel,
grids,
blocks,
smem_size,
stream,
reinterpret_cast<NV_TYPE *>(const_cast<T *>(qkv.data<T>())),
const_cast<uint8_t *>(cache_k.data<uint8_t>()),
const_cast<uint8_t *>(cache_v.data<uint8_t>()),
@@ -1648,7 +1669,12 @@ void MultiQueryAppendC8Attention(
num_chunks * num_heads));
}
}
launchWithPdlWhenEnabled(
split_kv_kernel,
grids,
blocks,
smem_size,
stream,
reinterpret_cast<NV_TYPE *>(const_cast<T *>(qkv.data<T>())),
const_cast<uint8_t *>(cache_k.data<uint8_t>()),
const_cast<uint8_t *>(cache_v.data<uint8_t>()),
@@ -1695,73 +1721,87 @@ void MultiQueryAppendC8Attention(
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(bsz, num_heads);
dim3 blocks_merge(blockx, blocky);
auto *kernelFn = merge_multi_chunks_decoder_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>;
launchWithPdlWhenEnabled(
kernelFn,
grids_merge,
blocks_merge,
0,
stream,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM);
} else {
constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
dim3 blocks_merge(blockx, blocky);
launchWithPdlWhenEnabled(
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>,
grids_merge,
blocks_merge,
0,
stream,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM,
token_num,
speculate_max_draft_token_num);
}
}
}

View File

@@ -16,10 +16,10 @@
#include <nvml.h>
float bfloat16_to_float(__nv_bfloat16 x) {
uint32_t tmp_x = *(reinterpret_cast<uint16_t*>(&x));
tmp_x = tmp_x << 16;
float float_x = *(reinterpret_cast<float*>(&tmp_x));
return float_x;
}
template <typename T>
@@ -27,120 +27,132 @@ static void PrintMatrix(const T* mat_d,
int num,
std::string name,
int numOfCols) {
  std::vector<T> tmp(num);
  cudaMemcpy(tmp.data(), mat_d, sizeof(T) * num, cudaMemcpyDeviceToHost);
  std::ofstream outfile;
  outfile.open(name + ".dtxt", std::ios::out | std::ios::app);
  std::stringstream ss;
  for (int i = 0; i < num; ++i) {
    if (std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value ||
        std::is_same<T, int32_t>::value) {
      ss << static_cast<int>(tmp[i]) << " ";
    } else {
      ss << std::setprecision(8) << static_cast<float>(tmp[i]) << " ";
    }
    if (i % numOfCols == numOfCols - 1) {
      ss << std::endl;
    }
  }
  outfile << ss.str();
  outfile.close();
}
GPUMemoryChecker::GPUMemoryChecker() {
  nvmlReturn_t result = nvmlInit_v2();
  if (NVML_SUCCESS != result) {
    throw std::runtime_error("Failed to initialize NVML: " +
                             std::string(nvmlErrorString(result)));
  }
  result = nvmlDeviceGetCount_v2(&deviceCount_);
  if (NVML_SUCCESS != result) {
    nvmlShutdown();
    throw std::runtime_error("Failed to get GPU count: " +
                             std::string(nvmlErrorString(result)));
  }
  getCUDAVisibleDevice();
}
GPUMemoryChecker::~GPUMemoryChecker() { nvmlShutdown(); }
void GPUMemoryChecker::getCUDAVisibleDevice() {
  const char* env_p = std::getenv("CUDA_VISIBLE_DEVICES");
  if (!env_p) {
    // CUDA_VISIBLE_DEVICES not set: every physical device is visible.
    // Keep both vectors in sync so addCheckPoint() can index them by position.
    for (int i = 0; i < deviceCount_; i++) {
      visible_device_.push_back(i);
      visible_device_mem_usage_.push_back(-1);
    }
    return;
  }
  std::string env_str(env_p);
  std::istringstream stream(env_str);
  std::string device_id;
  while (std::getline(stream, device_id, ',')) {
    visible_device_.push_back(std::stoi(device_id));
    visible_device_mem_usage_.push_back(-1);
  }
  std::cout << "\nVisible NVIDIA GPU devices: " << env_str << std::endl;
}
void GPUMemoryChecker::addCheckPoint(const char* call_file, int call_line) {
  try {
    for (int i = 0; i < visible_device_.size(); i++) {
      unsigned int device_id = visible_device_.at(i);
      nvmlDevice_t device;
      nvmlReturn_t result = nvmlDeviceGetHandleByIndex_v2(device_id, &device);
      if (NVML_SUCCESS != result) {
        std::cerr << "Failed to get handle for GPU " << device_id << ": "
                  << nvmlErrorString(result) << std::endl;
        continue;
      }
      char name[NVML_DEVICE_NAME_BUFFER_SIZE];
      result = nvmlDeviceGetName(device, name, NVML_DEVICE_NAME_BUFFER_SIZE);
      if (NVML_SUCCESS != result) {
        std::cerr << "Failed to get name for GPU " << device_id << ": "
                  << nvmlErrorString(result) << std::endl;
        continue;
      }
      nvmlMemory_t memoryInfo;
      result = nvmlDeviceGetMemoryInfo(device, &memoryInfo);
      if (NVML_SUCCESS != result) {
        std::cerr << "Failed to get memory info for GPU " << device_id << ": "
                  << nvmlErrorString(result) << std::endl;
        continue;
      }
      // Check that GPU memory did not grow since the previous checkpoint.
      const char* env_c = std::getenv("MEMCHECKER_CHECK_MEMORY");
      if (env_c) {
        assert(memoryInfo.used <= visible_device_mem_usage_.at(i) &&
               "GPU Memory does not allow growth!");
      }
      visible_device_mem_usage_[i] = memoryInfo.used;
    }
    // Optionally print per-GPU memory usage at this checkpoint.
    const char* env_p = std::getenv("MEMCHECKER_PRINT_MEMORY");
    if (env_p) {
      std::cout << "\nCall Line: " << call_line << "\t";
      for (int i = 0; i < visible_device_.size(); i++) {
        unsigned int device_id = visible_device_.at(i);
        // Index by position i (parallel to visible_device_), not by the
        // physical device id.
        std::cout << "GPU " << device_id << ": "
                  << " Used memory: "
                  << visible_device_mem_usage_.at(i) / (1024 * 1024)
                  << " MB\t";
      }
    }
  } catch (const std::exception& e) {
    std::cerr << "Error: " << e.what() << std::endl;
  }
}
bool getEnvEnablePDL() {
static std::once_flag flag;
static bool enablePDL = false;
std::call_once(flag, [&]() {
int sm_version = GetSMVersion();
if (sm_version >= 90) {
enablePDL = getBoolEnv("FD_ENABLE_PDL");
}
});
return enablePDL;
}
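Note that the result is cached via std::call_once, so FD_ENABLE_PDL must be set in the environment before the first PDL-capable launch of the process; changing it afterwards has no effect. A minimal host-side sketch, assuming a hypothetical helper name that is not part of this commit:
// Hypothetical helper: opt the current process into PDL launches before any
// attention kernel runs. getEnvEnablePDL() pins its first read via
// std::call_once, so later changes to the variable are ignored, and the flag
// only takes effect on SM90 (Hopper) or newer GPUs.
#include <cstdlib>
void enable_pdl_for_process() {
  setenv("FD_ENABLE_PDL", "1", /*overwrite=*/1);  // POSIX; read by getBoolEnv()
  (void)getEnvEnablePDL();  // first call fixes the cached value
}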

View File

@@ -20,6 +20,7 @@
#include "glog/logging.h"
#endif
#include <fcntl.h>
#include <nvml.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
@@ -27,19 +28,17 @@
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <cassert>
#include <cstdlib>
#include <cstring>
#ifdef PADDLE_WITH_HIP
#include <hip/hip_bfloat16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <hipcub/hipcub.hpp>
#include <hiprand.h>
#include <hiprand_kernel.h>
namespace cub = hipcub;
#else
#include <cub/cub.cuh>
@@ -58,8 +57,8 @@ namespace cub = hipcub;
#else
#include "paddle/phi/core/cuda_stream.h"
#endif
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/dense_tensor.h"
#ifdef PADDLE_WITH_COREX
#define WARP_SIZE 64
@@ -74,14 +73,16 @@ namespace cub = hipcub;
using json = nlohmann::json;
#endif
#define CUDA_CHECK(call) \
do { \
const cudaError_t error_code = call; \
if (error_code != cudaSuccess) { \
std::printf("at %s:%d - %s.\n", \
__FILE__, \
__LINE__, \
cudaGetErrorString(error_code)); \
exit(1); \
} \
} while (0)
#ifdef PADDLE_WITH_HIP
@@ -110,9 +111,10 @@ inline hipError_t GetNumBlocks(int64_t n, int *num_blocks) {
return err;
}
}
*num_blocks =
std::max<int>(1,
std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
sm_count * tpm / kBlockSize * kNumWaves));
return hipSuccess;
}
#else
@@ -141,9 +143,10 @@ inline cudaError_t GetNumBlocks(int64_t n, int *num_blocks) {
return err;
}
}
*num_blocks =
std::max<int>(1,
std::min<int64_t>((n + kBlockSize - 1) / kBlockSize,
sm_count * tpm / kBlockSize * kNumWaves));
return cudaSuccess;
}
@@ -163,51 +166,54 @@ inline int GetGPUComputeCapability(int id) {
#endif
#ifndef DISPATCH_FLOAT_FP6_DTYPE
#define DISPATCH_FLOAT_FP6_DTYPE(pd_dtype, c_type, ...) \
switch (pd_dtype) { \
case phi::DataType::FLOAT32: { \
using c_type = float; \
__VA_ARGS__ \
break; \
} \
case phi::DataType::BFLOAT16: { \
using c_type = phi::dtype::bfloat16; \
__VA_ARGS__ \
break; \
} \
case phi::DataType::FLOAT16: { \
using c_type = phi::dtype::float16; \
__VA_ARGS__ \
break; \
} \
default: { \
PD_THROW("Only supported attr of input type in [fp32, fp16, bf16]."); \
} \
}
#endif
inline constexpr uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
template <paddle::DataType D>
class PDTraits;
template <>
class PDTraits<paddle::DataType::FLOAT32> {
public:
typedef float DataType;
typedef float data_t;
};
template <>
class PDTraits<paddle::DataType::FLOAT16> {
public:
typedef half DataType;
typedef paddle::float16 data_t;
};
template <>
class PDTraits<paddle::DataType::BFLOAT16> {
public:
#ifdef PADDLE_WITH_HIP
typedef hip_bfloat16 DataType;
#else
@@ -216,27 +222,31 @@ public:
typedef paddle::bfloat16 data_t;
};
template <>
class PDTraits<paddle::DataType::INT8> {
public:
typedef int8_t DataType;
typedef int8_t data_t;
};
template <>
class PDTraits<paddle::DataType::UINT8> {
public:
typedef uint8_t DataType;
typedef uint8_t data_t;
};
#ifndef PADDLE_WITH_COREX
template <>
class PDTraits<paddle::DataType::FLOAT8_E4M3FN> {
public:
typedef __nv_fp8_e4m3 DataType;
typedef paddle::float8_e4m3fn data_t;
};
#endif
template <typename T, int Size>
struct alignas(sizeof(T) * Size) AlignedVector {
T val[Size];
HOSTDEVICE inline const T &operator[](int i) const { return val[i]; }
@@ -261,7 +271,7 @@ HOSTDEVICE inline void Store(const AlignedVector<T, Size> &vec, T *addr) {
template <int Size>
HOSTDEVICE inline void Store(const AlignedVector<hip_bfloat16, Size> &vec,
int8_t *addr) {
printf("Error: Store hip_bfloat16 to int8_t is not supported!");
printf("Error: Store hip_bfloat16 to int8_t is not supported!");
}
#else
template <int Size>
@@ -279,11 +289,13 @@ HOSTDEVICE inline void Store(const AlignedVector<half, Size> &vec,
constexpr int VEC_16B = 16;
template <typename T>
__device__ T max_func(const T a, const T b) {
return a > b ? a : b;
}
template <typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(const T &a, const T &b) const {
return max_func(a, b);
}
@@ -316,14 +328,14 @@ inline json readJsonFromFile(const std::string &filePath) {
}
#endif
#define cudaCheckError() \
{ \
cudaError_t e = cudaGetLastError(); \
if (e != cudaSuccess) { \
std::cerr << "CUDA Error " << __FILE__ << ":" << __LINE__ << ": " \
<< cudaGetErrorString(e) << std::endl; \
exit(EXIT_FAILURE); \
} \
}
// place must be an existing place object and cannot use paddle::CPUPlace() or
@@ -336,8 +348,8 @@ inline paddle::Tensor GetEmptyTensor(const common::DDim &dims,
auto *allocator = paddle::GetAllocator(place);
phi::DenseTensor dense_tensor;
dense_tensor.Resize(dims);
dense_tensor.AllocateFrom(
allocator, dtype, dense_tensor.numel() * phi::SizeOf(dtype));
return paddle::Tensor(std::make_shared<phi::DenseTensor>(dense_tensor));
}
@@ -348,39 +360,63 @@ inline paddle::Tensor GetEmptyTensor(const common::DDim &dims,
auto *allocator = paddle::GetAllocator(place);
phi::DenseTensor dense_tensor;
dense_tensor.Resize(dims);
dense_tensor.AllocateFrom(
allocator, dtype, dense_tensor.numel() * phi::SizeOf(dtype));
dense_tensor.set_strides(strides);
return paddle::Tensor(std::make_shared<phi::DenseTensor>(dense_tensor));
}
#endif
__global__ void free_and_dispatch_block(bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len,
int *recover_block_list,
int *recover_len,
int *need_block_list,
int *need_block_len,
int *used_list_len,
int *free_list,
int *free_list_len,
int64_t *first_token_ids,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num);
__global__ void speculate_free_and_dispatch_block(
bool *stop_flags,
int *seq_lens_this_time,
int *seq_lens_decoder,
int *block_tables,
int *encoder_block_lens,
bool *is_block_step,
int *step_block_list, // [bsz]
int *step_len,
int *recover_block_list,
int *recover_len,
int *need_block_list,
int *need_block_len,
int *used_list_len,
int *free_list,
int *free_list_len,
int64_t *first_token_ids,
int *accept_num,
const int bsz,
const int block_size,
const int block_num_per_seq,
const int max_decoder_block_num,
const int max_draft_tokens);
__device__ bool speculate_free_and_dispatch_block(const int &qid,
int *need_block_list,
const int &need_block_len);
static std::string global_base64_chars = // NOLINT
"Tokp9lA/BjimRVKx32edMPFftOzsbNQ8C15Xn+YUEGc4WD0uLIq7hyJ6vZaHSwrg";
// Base64 encoding function
@@ -501,7 +537,8 @@ inline T get_relative_best(nlohmann::json *json_data,
}
#endif
__device__ inline bool is_in_end(const int64_t id,
const int64_t *end_ids,
int length) {
bool flag = false;
for (int i = 0; i < length; i++) {
@@ -512,22 +549,20 @@ __device__ inline bool is_in_end(const int64_t id, const int64_t *end_ids,
return flag;
}
template <typename T>
inline __device__ __host__ T div_up(T m, T n) {
return (m + n - 1) / n;
}
template <typename T>
__device__ __inline__ T ClipFunc(const T v, const T min, const T max) {
if (v > max) return max;
if (v < min) return min;
return v;
}
template <typename T>
static void PrintMatrix3(const T *mat_d, int num, std::string name) {
std::vector<T> tmp(num);
#ifdef PADDLE_WITH_HIP
hipMemcpy(tmp.data(), mat_d, sizeof(T) * num, hipMemcpyDeviceToHost);
@@ -535,7 +570,6 @@ static void PrintMatrix3(const T *mat_d, int num, std::string name) {
cudaMemcpy(tmp.data(), mat_d, sizeof(T) * num, cudaMemcpyDeviceToHost);
#endif
std::ofstream outfile;
outfile.open(name + ".txt", std::ios::out);
std::stringstream ss;
@@ -544,7 +578,7 @@ static void PrintMatrix3(const T *mat_d, int num, std::string name) {
if (std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value) {
ss << static_cast<int>(tmp[i]) << std::endl;
} else {
ss << std::setprecision(8) << (float)(tmp[i]) << std::endl; // NOLINT
}
}
outfile << ss.str();
@@ -573,7 +607,8 @@ __forceinline__ __device__ uint32_t ld_flag_acquire(uint32_t *flag_addr,
}
__forceinline__ __device__ void st_flag_release(uint32_t *flag_addr,
uint32_t flag,
int mode = 0) {
if (mode == 0) {
asm volatile("st.release.sys.global.b32 [%1], %0;" ::"r"(flag),
"l"(flag_addr));
@@ -589,7 +624,8 @@ __forceinline__ __device__ void st_flag_release(uint32_t *flag_addr,
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
int max_shared_mem_per_block_opt_in = 0;
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
cudaDevAttrMaxSharedMemoryPerBlockOptin,
device);
return max_shared_mem_per_block_opt_in;
}
#endif
@@ -627,29 +663,29 @@ inline bool checkAttentionBackend() {
#ifndef GPU_MEMORY_CHECKER_H
#define GPU_MEMORY_CHECKER_H
class GPUMemoryChecker {
public:
static GPUMemoryChecker *getInstance() {
static GPUMemoryChecker instance;
return &instance;
}
void addCheckPoint(const char *call_file, int call_line);
unsigned int getGPUCount() const { return deviceCount_; }
void getCUDAVisibleDevice();
GPUMemoryChecker(const GPUMemoryChecker &) = delete;
void operator=(const GPUMemoryChecker &) = delete;
private:
GPUMemoryChecker();
~GPUMemoryChecker();
unsigned int deviceCount_;
std::vector<unsigned int> visible_device_;
std::vector<unsigned int> visible_device_mem_usage_;
};
#endif // GPU_MEMORY_CHECKER_H
__device__ __forceinline__ float warpReduceMax(float value) {
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 16));
value = fmaxf(value, __shfl_xor_sync(0xffffffff, value, 8));
@@ -674,3 +710,31 @@ __device__ __forceinline__ float blockReduceMax(float value) {
return value;
}
inline bool getBoolEnv(char const *name) {
char const *env = std::getenv(name);
return env && env[0] == '1' && env[1] == '\0';
}
bool getEnvEnablePDL();
template <typename KernelFn, typename... Args>
inline void launchWithPdlWhenEnabled(KernelFn kernelFn,
dim3 grid,
dim3 block,
size_t dynamicShmSize,
cudaStream_t stream,
Args &&...args) {
cudaLaunchConfig_t kernelConfig;
kernelConfig.gridDim = grid;
kernelConfig.blockDim = block;
kernelConfig.dynamicSmemBytes = dynamicShmSize;
kernelConfig.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = getEnvEnablePDL();
kernelConfig.attrs = attrs;
kernelConfig.numAttrs = 1;
cudaLaunchKernelEx(&kernelConfig, kernelFn, std::forward<Args>(args)...);
}
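As a usage illustration only (scale_kernel and launch_scale below are toy examples, not part of this commit), the helper takes the same grid/block/shared-memory/stream arguments as a <<<...>>> launch, and the kernel opts into PDL with the same SM90-guarded intrinsics used by the attention kernels above:
// Toy consumer kernel: with PDL it may be launched while the producer on the
// same stream is still running; cudaGridDependencySynchronize() delays the
// dependent reads until the producer's writes are visible, and
// cudaTriggerProgrammaticLaunchCompletion() lets the next dependent grid
// start early.
__global__ void scale_kernel(const float *in, float *out, float alpha, int n) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaGridDependencySynchronize();
#endif
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) out[idx] = alpha * in[idx];
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  cudaTriggerProgrammaticLaunchCompletion();
#endif
}
// Host side: argument-for-argument replacement of
// scale_kernel<<<grid, block, 0, stream>>>(in, out, alpha, n). When
// FD_ENABLE_PDL is off or the GPU is pre-SM90, the attribute value is 0 and
// this degrades to an ordinary stream-ordered launch.
void launch_scale(const float *in, float *out, float alpha, int n,
                  cudaStream_t stream) {
  dim3 block(256);
  dim3 grid((n + block.x - 1) / block.x);
  launchWithPdlWhenEnabled(
      scale_kernel, grid, block, 0, stream, in, out, alpha, n);
}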

View File

@@ -448,6 +448,7 @@ class LLMEngine:
"NCCL_ALGO": "Ring",
"FLAGS_max_partition_size": int(os.getenv("FLAGS_max_partition_size", 1024)),
"OMP_NUM_THREADS": int(os.getenv("OMP_NUM_THREADS", 3)),
"FD_ENABLE_PDL": envs.FD_ENABLE_PDL,
}
# environment variables needed by Dy2St
variables.update(

View File

@@ -159,6 +159,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
"FD_OFFLINE_PERF_TEST_FOR_PD": lambda: int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0")),
"FD_ENABLE_E2W_TENSOR_CONVERT": lambda: int(os.getenv("FD_ENABLE_E2W_TENSOR_CONVERT", "0")),
"FD_ENGINE_TASK_QUEUE_WITH_SHM": lambda: int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "0")),
"FD_ENABLE_PDL": lambda: int(os.getenv("FD_ENABLE_PDL", "1")),
}