mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Feature][OP] Append Attn Support CUDA-PDL (#5072)
This commit is contained in:
@@ -2296,6 +2296,9 @@ __global__ void merge_multi_chunks_decoder_kernel(
|
||||
const int bid = blockIdx.x, hid = blockIdx.y;
|
||||
__shared__ T smem[bdy * HEAD_DIM];
|
||||
__shared__ float md_smem[bdy * 2];
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaGridDependencySynchronize();
|
||||
#endif
|
||||
const int start_token_idx = cu_seqlens_q[bid];
|
||||
const int seq_len_q = seq_lens_q[bid];
|
||||
if (seq_len_q == 0) return;
|
||||
@@ -2332,6 +2335,10 @@ __global__ void merge_multi_chunks_decoder_kernel(
|
||||
} else if constexpr (std::is_same<T, __nv_bfloat16>::value) {
|
||||
m = -3.0e+30f;
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaGridDependencySynchronize();
|
||||
#endif
|
||||
|
||||
#pragma unroll 2
|
||||
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
|
||||
uint32_t offset = (bid * num_chunks + i) * num_heads + hid;
|
||||
@@ -2397,6 +2404,9 @@ __global__ void merge_multi_chunks_decoder_kernel(
|
||||
out_vec,
|
||||
&out[(start_token_idx * num_heads + hid) * head_dim + vid * vec_size]);
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
@@ -2433,6 +2443,9 @@ __global__ void merge_multi_chunks_v2_kernel(
|
||||
const int hid = blockIdx.y;
|
||||
__shared__ T smem[bdy * HEAD_DIM];
|
||||
__shared__ float md_smem[bdy * 2];
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaGridDependencySynchronize();
|
||||
#endif
|
||||
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
|
||||
const uint32_t bid = batch_id_per_token[qid];
|
||||
if (bid == -1) {
|
||||
@@ -2569,4 +2582,7 @@ __global__ void merge_multi_chunks_v2_kernel(
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
#endif
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user