[Feature][OP] Append Attn Support CUDA-PDL (#5072)

Commit: d58c1db8a0 (parent: c2c1942db9)
Author: chen
Date: 2025-11-17 20:47:33 +08:00
Committed by: GitHub
12 changed files with 2828 additions and 2068 deletions

View File

@@ -2296,6 +2296,9 @@ __global__ void merge_multi_chunks_decoder_kernel(
const int bid = blockIdx.x, hid = blockIdx.y;
__shared__ T smem[bdy * HEAD_DIM];
__shared__ float md_smem[bdy * 2];
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
const int start_token_idx = cu_seqlens_q[bid];
const int seq_len_q = seq_lens_q[bid];
if (seq_len_q == 0) return;
@@ -2332,6 +2335,10 @@ __global__ void merge_multi_chunks_decoder_kernel(
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
m = -3.0e+30f;
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
#pragma unroll 2
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
uint32_t offset = (bid * num_chunks + i) * num_heads + hid;
@@ -2397,6 +2404,9 @@ __global__ void merge_multi_chunks_decoder_kernel(
out_vec,
&out[(start_token_idx * num_heads + hid) * head_dim + vid * vec_size]);
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
template <typename T,
@@ -2433,6 +2443,9 @@ __global__ void merge_multi_chunks_v2_kernel(
const int hid = blockIdx.y;
__shared__ T smem[bdy * HEAD_DIM];
__shared__ float md_smem[bdy * 2];
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
const uint32_t bid = batch_id_per_token[qid];
if (bid == -1) {
@@ -2569,4 +2582,7 @@ __global__ void merge_multi_chunks_v2_kernel(
}
__syncthreads();
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
}