supports mtp split_kv_attn (#5343)

This commit is contained in:
lzy
2025-12-03 12:40:16 +08:00
committed by GitHub
parent dfeabee123
commit c71a44c7e5
4 changed files with 212 additions and 531 deletions

View File

@@ -2451,7 +2451,6 @@ __global__ void merge_multi_chunks_v2_kernel(
if (bid == -1) { if (bid == -1) {
continue; continue;
} }
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
const int seq_len_q = seq_lens_q[bid]; const int seq_len_q = seq_lens_q[bid];
if (seq_len_q == 0) continue; if (seq_len_q == 0) continue;
int seq_len_kv = seq_lens_kv[bid]; int seq_len_kv = seq_lens_kv[bid];
@@ -2470,8 +2469,6 @@ __global__ void merge_multi_chunks_v2_kernel(
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size); const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
if (num_chunks_this_seq <= 1) { if (num_chunks_this_seq <= 1) {
continue; continue;
} else if (!ENABLE_PREFILL) {
continue;
} }
using LoadT = AlignedVector<T, vec_size>; using LoadT = AlignedVector<T, vec_size>;
@@ -2497,32 +2494,14 @@ __global__ void merge_multi_chunks_v2_kernel(
} }
#pragma unroll 2 #pragma unroll 2
for (int i = ty; i < num_chunks_this_seq; i += bdy) { for (int i = ty; i < num_chunks_this_seq; i += bdy) {
uint32_t offset; uint32_t offset = (qid * num_chunks + i) * num_heads + hid;
if (ENABLE_PREFILL) {
offset = (qid * num_chunks + i) * num_heads + hid;
} else {
offset =
((bid * speculate_max_draft_token_num + local_seq_id) * num_chunks +
i) *
num_heads +
hid;
}
float m_prev = m; float m_prev = m;
float d_prev = d; float d_prev = d;
const float m_now = multi_m[offset]; const float m_now = multi_m[offset];
const float d_now = multi_d[offset]; const float d_now = multi_d[offset];
m = max(m_prev, m_now); m = max(m_prev, m_now);
if (ENABLE_PREFILL) { offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim +
offset = vid * vec_size;
(qid * num_chunks * num_heads + i * num_heads + hid) * head_dim +
vid * vec_size;
} else {
offset = ((bid * speculate_max_draft_token_num + local_seq_id) *
num_chunks * num_heads +
i * num_heads + hid) *
head_dim +
vid * vec_size;
}
Load<T, vec_size>(&multi_out[offset], &load_vec); Load<T, vec_size>(&multi_out[offset], &load_vec);
const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m); const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
const T scale1_T = static_cast<T>(scale1), const T scale1_T = static_cast<T>(scale1),

View File

@@ -134,17 +134,9 @@ __global__ void multi_query_append_attention_kernel(
T *o_base_ptr_T = nullptr; T *o_base_ptr_T = nullptr;
OutT *o_base_ptr_int8 = nullptr; OutT *o_base_ptr_int8 = nullptr;
if constexpr (partition_kv) { if constexpr (partition_kv) {
if (ENABLE_PREFILL) { o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + tid % 8 * num_elems_per_128b<T>();
tid % 8 * num_elems_per_128b<T>();
} else {
o_base_ptr_T =
tmp_workspace +
batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
tid % 8 * num_elems_per_128b<T>();
}
} else { } else {
o_base_ptr_int8 = out + o_offset; o_base_ptr_int8 = out + o_offset;
} }
@@ -394,18 +386,8 @@ __global__ void multi_query_append_attention_kernel(
const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE;
const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
if (qo_idx - q_start_seq_id < q_len) { if (qo_idx - q_start_seq_id < q_len) {
uint32_t offset; uint32_t offset =
if (ENABLE_PREFILL) { (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
offset =
(qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
} else {
offset = ((batch_id * speculate_max_draft_token_num +
qo_idx_now / GROUP_SIZE) *
num_chunks +
chunk_idx) *
q_num_heads +
qo_head_idx;
}
tmp_m[offset] = m_frag[fx][j]; tmp_m[offset] = m_frag[fx][j];
tmp_d[offset] = d_frag[fx][j]; tmp_d[offset] = d_frag[fx][j];
} }
@@ -542,11 +524,9 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
tid % 8 * num_elems_per_128b<T>(); tid % 8 * num_elems_per_128b<T>();
} else { } else {
o_base_ptr_T = o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
tmp_workspace + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + tid % 8 * num_elems_per_128b<T>();
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
tid % 8 * num_elems_per_128b<T>();
} }
} }
const int *mask_offset_this_seq = const int *mask_offset_this_seq =
@@ -814,12 +794,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
offset = (batch_id * num_chunks + chunk_idx) * q_num_heads + offset = (batch_id * num_chunks + chunk_idx) * q_num_heads +
qo_head_idx; qo_head_idx;
} else { } else {
offset = ((batch_id * speculate_max_draft_token_num + offset =
qo_idx_now / GROUP_SIZE) * (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
num_chunks +
chunk_idx) *
q_num_heads +
qo_head_idx;
} }
tmp_m[offset] = m_frag[fx][j]; tmp_m[offset] = m_frag[fx][j];
tmp_d[offset] = d_frag[fx][j]; tmp_d[offset] = d_frag[fx][j];
@@ -918,10 +894,7 @@ void MultiQueryAppendAttention(
int sm_count; int sm_count;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size); uint32_t chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
if (!is_decoder) {
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
}
const int num_chunks = div_up(max_dec_len, chunk_size); const int num_chunks = div_up(max_dec_len, chunk_size);
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
dim3 blocks(32, num_warps); dim3 blocks(32, num_warps);
@@ -1053,95 +1026,51 @@ void MultiQueryAppendAttention(
sliding_window); sliding_window);
// merge // merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>(); constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) { constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blockx = HEAD_DIM / vec_size; constexpr int blocky = (128 + blockx - 1) / blockx;
constexpr int blocky = (128 + blockx - 1) / blockx; dim3 grids_merge(min(sm_count * 4, token_num),
dim3 grids_merge(bsz, num_heads); num_heads); // 128k is too large
dim3 blocks_merge(blockx, blocky); dim3 blocks_merge(blockx, blocky);
auto *kernelFn = merge_multi_chunks_decoder_kernel<NV_TYPE, auto *kernelFn = merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size, vec_size,
blocky, blocky,
HEAD_DIM, HEAD_DIM,
OUT_NV_TYPE, OUT_NV_TYPE,
ENABLE_PREFILL>; ENABLE_PREFILL>;
launchWithPdlWhenEnabled( launchWithPdlWhenEnabled(
kernelFn, kernelFn,
grids_merge, grids_merge,
blocks_merge, blocks_merge,
0, 0,
stream, stream,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()), reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()), static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()), static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(), seq_lens_q.data<int>(),
seq_lens_kv.data<int>(), seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(), seq_lens_encoder.data<int>(),
cu_seqlens_q.data<int>(), batch_id_per_token.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>( cu_seqlens_q.data<int>(),
const_cast<T *>(shift_bias.get().data<T>())) shift_bias ? reinterpret_cast<NV_TYPE *>(
: nullptr, const_cast<T *>(shift_bias.get().data<T>()))
smooth_weight ? reinterpret_cast<NV_TYPE *>( : nullptr,
const_cast<T *>(smooth_weight.get().data<T>())) smooth_weight ? reinterpret_cast<NV_TYPE *>(
: nullptr, const_cast<T *>(smooth_weight.get().data<T>()))
sinks ? reinterpret_cast<NV_TYPE *>( : nullptr,
const_cast<T *>(sinks.get().data<T>())) sinks ? reinterpret_cast<NV_TYPE *>(
: nullptr, const_cast<T *>(sinks.get().data<T>()))
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()), : nullptr,
quant_max_bound, reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_min_bound, quant_max_bound,
in_scale, quant_min_bound,
max_seq_len, in_scale,
num_chunks, max_seq_len,
num_heads, num_chunks,
chunk_size, num_heads,
HEAD_DIM); chunk_size,
} else { HEAD_DIM,
constexpr int blockx = HEAD_DIM / vec_size; token_num,
constexpr int blocky = (128 + blockx - 1) / blockx; speculate_max_draft_token_num);
dim3 grids_merge(min(sm_count * 4, token_num),
num_heads); // 128k is too large
dim3 blocks_merge(blockx, blocky);
auto *kernelFn = merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>;
launchWithPdlWhenEnabled(
kernelFn,
grids_merge,
blocks_merge,
0,
stream,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM,
token_num,
speculate_max_draft_token_num);
}
} }
} else { } else {
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV; constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV;
@@ -1173,9 +1102,6 @@ void MultiQueryAppendAttention(
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size); uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
if (!is_decoder) {
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
}
uint32_t attn_mask_len; uint32_t attn_mask_len;
if (attn_mask) { if (attn_mask) {
@@ -1263,31 +1189,15 @@ void MultiQueryAppendAttention(
phi::SizeOf(paddle::DataType::FLOAT32) * phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(bsz * num_chunks * num_heads)); static_cast<size_t>(bsz * num_chunks * num_heads));
} else { } else {
if (ENABLE_PREFILL) { tmp_workspace = allocator->Allocate(
tmp_workspace = phi::SizeOf(qkv.dtype()) *
allocator->Allocate(phi::SizeOf(qkv.dtype()) * static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
static_cast<size_t>(token_num * num_chunks * tmp_m = allocator->Allocate(
num_heads * HEAD_DIM)); phi::SizeOf(paddle::DataType::FLOAT32) *
tmp_m = allocator->Allocate( static_cast<size_t>(token_num * num_chunks * num_heads));
phi::SizeOf(paddle::DataType::FLOAT32) * tmp_d = allocator->Allocate(
static_cast<size_t>(token_num * num_chunks * num_heads)); phi::SizeOf(paddle::DataType::FLOAT32) *
tmp_d = allocator->Allocate( static_cast<size_t>(token_num * num_chunks * num_heads));
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(token_num * num_chunks * num_heads));
} else {
tmp_workspace = allocator->Allocate(
phi::SizeOf(qkv.dtype()) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads * HEAD_DIM));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads));
tmp_d = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads));
}
} }
launchWithPdlWhenEnabled( launchWithPdlWhenEnabled(
split_kv_kernel, split_kv_kernel,

View File

@@ -169,17 +169,9 @@ __global__ void multi_query_append_attention_c4_kernel(
T *o_base_ptr_T = nullptr; T *o_base_ptr_T = nullptr;
OutT *o_base_ptr_int8 = nullptr; OutT *o_base_ptr_int8 = nullptr;
if constexpr (partition_kv) { if constexpr (partition_kv) {
if (ENABLE_PREFILL) { o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + tid % 8 * num_elems_per_128b<T>();
tid % 8 * num_elems_per_128b<T>();
} else {
o_base_ptr_T =
tmp_workspace +
batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
tid % 8 * num_elems_per_128b<T>();
}
} else { } else {
o_base_ptr_int8 = out + o_offset; o_base_ptr_int8 = out + o_offset;
} }
@@ -485,18 +477,8 @@ __global__ void multi_query_append_attention_c4_kernel(
const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE;
const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
if (qo_idx - q_start_seq_id < q_len) { if (qo_idx - q_start_seq_id < q_len) {
uint32_t offset; uint32_t offset =
if (ENABLE_PREFILL) { (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
offset =
(qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
} else {
offset = ((batch_id * speculate_max_draft_token_num +
qo_idx_now / GROUP_SIZE) *
num_chunks +
chunk_idx) *
q_num_heads +
qo_head_idx;
}
tmp_m[offset] = m_frag[fx][j]; tmp_m[offset] = m_frag[fx][j];
tmp_d[offset] = d_frag[fx][j]; tmp_d[offset] = d_frag[fx][j];
} }
@@ -669,11 +651,9 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
tid % 8 * num_elems_per_128b<T>(); tid % 8 * num_elems_per_128b<T>();
} else { } else {
o_base_ptr_T = o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
tmp_workspace + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + tid % 8 * num_elems_per_128b<T>();
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
tid % 8 * num_elems_per_128b<T>();
} }
} }
const int *mask_offset_this_seq = const int *mask_offset_this_seq =
@@ -989,12 +969,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
offset = (batch_id * num_chunks + chunk_idx) * q_num_heads + offset = (batch_id * num_chunks + chunk_idx) * q_num_heads +
qo_head_idx; qo_head_idx;
} else { } else {
offset = ((batch_id * speculate_max_draft_token_num + offset =
qo_idx_now / GROUP_SIZE) * (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
num_chunks +
chunk_idx) *
q_num_heads +
qo_head_idx;
} }
tmp_m[offset] = m_frag[fx][j]; tmp_m[offset] = m_frag[fx][j];
tmp_d[offset] = d_frag[fx][j]; tmp_d[offset] = d_frag[fx][j];
@@ -1108,10 +1084,7 @@ void MultiQueryAppendC4Attention(
const float ratio = static_cast<float>(num_blocks_need) / const float ratio = static_cast<float>(num_blocks_need) /
static_cast<float>(num_blocks_per_wave); static_cast<float>(num_blocks_per_wave);
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size); uint32_t chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
if (!is_decoder) {
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
}
const int num_chunks = div_up(max_dec_len, chunk_size); const int num_chunks = div_up(max_dec_len, chunk_size);
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
@@ -1188,30 +1161,15 @@ void MultiQueryAppendC4Attention(
sliding_window); sliding_window);
} else { } else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (ENABLE_PREFILL) { tmp_workspace = allocator->Allocate(
tmp_workspace = allocator->Allocate( phi::SizeOf(qkv.dtype()) *
phi::SizeOf(qkv.dtype()) * static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM)); tmp_m = allocator->Allocate(
tmp_m = allocator->Allocate( phi::SizeOf(paddle::DataType::FLOAT32) *
phi::SizeOf(paddle::DataType::FLOAT32) * static_cast<size_t>(token_num * num_chunks * num_heads));
static_cast<size_t>(token_num * num_chunks * num_heads)); tmp_d = allocator->Allocate(
tmp_d = allocator->Allocate( phi::SizeOf(paddle::DataType::FLOAT32) *
phi::SizeOf(paddle::DataType::FLOAT32) * static_cast<size_t>(token_num * num_chunks * num_heads));
static_cast<size_t>(token_num * num_chunks * num_heads));
} else {
tmp_workspace = allocator->Allocate(
phi::SizeOf(qkv.dtype()) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads * HEAD_DIM));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads));
tmp_d = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads));
}
launchWithPdlWhenEnabled( launchWithPdlWhenEnabled(
split_kv_kernel, split_kv_kernel,
grids, grids,
@@ -1262,92 +1220,49 @@ void MultiQueryAppendC4Attention(
sliding_window); sliding_window);
// merge // merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>(); constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) { constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blockx = HEAD_DIM / vec_size; constexpr int blocky = (128 + blockx - 1) / blockx;
constexpr int blocky = (128 + blockx - 1) / blockx; dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
dim3 grids_merge(bsz, num_heads); dim3 blocks_merge(blockx, blocky);
dim3 blocks_merge(blockx, blocky); launchWithPdlWhenEnabled(
launchWithPdlWhenEnabled( merge_multi_chunks_v2_kernel<NV_TYPE,
merge_multi_chunks_decoder_kernel<NV_TYPE, vec_size,
vec_size, blocky,
blocky, HEAD_DIM,
HEAD_DIM, OUT_NV_TYPE,
OUT_NV_TYPE, ENABLE_PREFILL>,
ENABLE_PREFILL>, grids_merge,
grids_merge, blocks_merge,
blocks_merge, 0,
0, stream,
stream, reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()), static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_m->ptr()), static_cast<float *>(tmp_d->ptr()),
static_cast<float *>(tmp_d->ptr()), seq_lens_q.data<int>(),
seq_lens_q.data<int>(), seq_lens_kv.data<int>(),
seq_lens_kv.data<int>(), seq_lens_encoder.data<int>(),
seq_lens_encoder.data<int>(), batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(), cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>( shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>())) const_cast<T *>(shift_bias.get().data<T>()))
: nullptr, : nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>( smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>())) const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr, : nullptr,
sinks ? reinterpret_cast<NV_TYPE *>( sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>())) const_cast<T *>(sinks.get().data<T>()))
: nullptr, : nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()), reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound, quant_max_bound,
quant_min_bound, quant_min_bound,
in_scale, in_scale,
max_seq_len, max_seq_len,
num_chunks, num_chunks,
num_heads, num_heads,
chunk_size, chunk_size,
HEAD_DIM); HEAD_DIM,
} else { token_num,
constexpr int blockx = HEAD_DIM / vec_size; speculate_max_draft_token_num);
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
dim3 blocks_merge(blockx, blocky);
launchWithPdlWhenEnabled(
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>,
grids_merge,
blocks_merge,
0,
stream,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM,
token_num,
speculate_max_draft_token_num);
}
} }
} else { } else {
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 4; constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 4;
@@ -1390,9 +1305,6 @@ void MultiQueryAppendC4Attention(
static_cast<float>(num_blocks_per_wave); static_cast<float>(num_blocks_per_wave);
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size); uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
if (!is_decoder) {
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
}
const int num_chunks = div_up(max_seq_len, chunk_size); const int num_chunks = div_up(max_seq_len, chunk_size);
uint32_t attn_mask_len; uint32_t attn_mask_len;
@@ -1490,31 +1402,15 @@ void MultiQueryAppendC4Attention(
phi::SizeOf(paddle::DataType::FLOAT32) * phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(bsz * num_chunks * num_heads)); static_cast<size_t>(bsz * num_chunks * num_heads));
} else { } else {
if (ENABLE_PREFILL) { tmp_workspace = allocator->Allocate(
tmp_workspace = phi::SizeOf(qkv.dtype()) *
allocator->Allocate(phi::SizeOf(qkv.dtype()) * static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
static_cast<size_t>(token_num * num_chunks * tmp_m = allocator->Allocate(
num_heads * HEAD_DIM)); phi::SizeOf(paddle::DataType::FLOAT32) *
tmp_m = allocator->Allocate( static_cast<size_t>(token_num * num_chunks * num_heads));
phi::SizeOf(paddle::DataType::FLOAT32) * tmp_d = allocator->Allocate(
static_cast<size_t>(token_num * num_chunks * num_heads)); phi::SizeOf(paddle::DataType::FLOAT32) *
tmp_d = allocator->Allocate( static_cast<size_t>(token_num * num_chunks * num_heads));
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(token_num * num_chunks * num_heads));
} else {
tmp_workspace = allocator->Allocate(
phi::SizeOf(qkv.dtype()) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads * HEAD_DIM));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads));
tmp_d = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads));
}
} }
launchWithPdlWhenEnabled( launchWithPdlWhenEnabled(
split_kv_kernel, split_kv_kernel,

View File

@@ -178,17 +178,9 @@ __global__ void multi_query_append_attention_c8_kernel(
T *o_base_ptr_T = nullptr; T *o_base_ptr_T = nullptr;
OutT *o_base_ptr_int8 = nullptr; OutT *o_base_ptr_int8 = nullptr;
if constexpr (partition_kv) { if constexpr (partition_kv) {
if (ENABLE_PREFILL) { o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + tid % 8 * num_elems_per_128b<T>();
tid % 8 * num_elems_per_128b<T>();
} else {
o_base_ptr_T =
tmp_workspace +
batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
tid % 8 * num_elems_per_128b<T>();
}
} else { } else {
o_base_ptr_int8 = out + o_offset; o_base_ptr_int8 = out + o_offset;
} }
@@ -532,18 +524,8 @@ __global__ void multi_query_append_attention_c8_kernel(
const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE;
const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
if (qo_idx - q_start_seq_id < q_len) { if (qo_idx - q_start_seq_id < q_len) {
uint32_t offset; uint32_t offset =
if (ENABLE_PREFILL) { (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
offset =
(qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
} else {
offset = ((batch_id * speculate_max_draft_token_num +
qo_idx_now / GROUP_SIZE) *
num_chunks +
chunk_idx) *
q_num_heads +
qo_head_idx;
}
tmp_m[offset] = m_frag[fx][j]; tmp_m[offset] = m_frag[fx][j];
tmp_d[offset] = d_frag[fx][j]; tmp_d[offset] = d_frag[fx][j];
} }
@@ -720,11 +702,9 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
tid % 8 * num_elems_per_128b<T>(); tid % 8 * num_elems_per_128b<T>();
} else { } else {
o_base_ptr_T = o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
tmp_workspace + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + tid % 8 * num_elems_per_128b<T>();
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
tid % 8 * num_elems_per_128b<T>();
} }
} }
const int *mask_offset_this_seq = const int *mask_offset_this_seq =
@@ -1083,12 +1063,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
offset = (batch_id * num_chunks + chunk_idx) * q_num_heads + offset = (batch_id * num_chunks + chunk_idx) * q_num_heads +
qo_head_idx; qo_head_idx;
} else { } else {
offset = ((batch_id * speculate_max_draft_token_num + offset =
qo_idx_now / GROUP_SIZE) * (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
num_chunks +
chunk_idx) *
q_num_heads +
qo_head_idx;
} }
tmp_m[offset] = m_frag[fx][j]; tmp_m[offset] = m_frag[fx][j];
tmp_d[offset] = d_frag[fx][j]; tmp_d[offset] = d_frag[fx][j];
@@ -1218,10 +1194,7 @@ void MultiQueryAppendC8Attention(
const int dev_id = 0; const int dev_id = 0;
int sm_count; int sm_count;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size); uint32_t chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
if (!is_decoder) {
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
}
const int num_chunks = div_up(max_dec_len, chunk_size); const int num_chunks = div_up(max_dec_len, chunk_size);
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
dim3 blocks(32, num_warps); dim3 blocks(32, num_warps);
@@ -1315,30 +1288,15 @@ void MultiQueryAppendC8Attention(
sliding_window); sliding_window);
} else { } else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (ENABLE_PREFILL) { tmp_workspace = allocator->Allocate(
tmp_workspace = allocator->Allocate( phi::SizeOf(qkv.dtype()) *
phi::SizeOf(qkv.dtype()) * static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM)); tmp_m = allocator->Allocate(
tmp_m = allocator->Allocate( phi::SizeOf(paddle::DataType::FLOAT32) *
phi::SizeOf(paddle::DataType::FLOAT32) * static_cast<size_t>(token_num * num_chunks * num_heads));
static_cast<size_t>(token_num * num_chunks * num_heads)); tmp_d = allocator->Allocate(
tmp_d = allocator->Allocate( phi::SizeOf(paddle::DataType::FLOAT32) *
phi::SizeOf(paddle::DataType::FLOAT32) * static_cast<size_t>(token_num * num_chunks * num_heads));
static_cast<size_t>(token_num * num_chunks * num_heads));
} else {
tmp_workspace = allocator->Allocate(
phi::SizeOf(qkv.dtype()) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads * HEAD_DIM));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads));
tmp_d = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads));
}
launchWithPdlWhenEnabled( launchWithPdlWhenEnabled(
split_kv_kernel, split_kv_kernel,
grids, grids,
@@ -1383,92 +1341,49 @@ void MultiQueryAppendC8Attention(
sliding_window); sliding_window);
// merge // merge
constexpr int vec_size = num_elems_per_128b<NV_TYPE>(); constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
if (is_decoder) { constexpr int blockx = HEAD_DIM / vec_size;
constexpr int blockx = HEAD_DIM / vec_size; constexpr int blocky = (128 + blockx - 1) / blockx;
constexpr int blocky = (128 + blockx - 1) / blockx; dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
dim3 grids_merge(bsz, num_heads); dim3 blocks_merge(blockx, blocky);
dim3 blocks_merge(blockx, blocky); launchWithPdlWhenEnabled(
launchWithPdlWhenEnabled( merge_multi_chunks_v2_kernel<NV_TYPE,
merge_multi_chunks_decoder_kernel<NV_TYPE, vec_size,
vec_size, blocky,
blocky, HEAD_DIM,
HEAD_DIM, OUT_NV_TYPE,
OUT_NV_TYPE, ENABLE_PREFILL>,
ENABLE_PREFILL>, grids_merge,
grids_merge, blocks_merge,
blocks_merge, 0,
0, stream,
stream, reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()), static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_m->ptr()), static_cast<float *>(tmp_d->ptr()),
static_cast<float *>(tmp_d->ptr()), seq_lens_q.data<int>(),
seq_lens_q.data<int>(), seq_lens_kv.data<int>(),
seq_lens_kv.data<int>(), seq_lens_encoder.data<int>(),
seq_lens_encoder.data<int>(), batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(), cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>( shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>())) const_cast<T *>(shift_bias.get().data<T>()))
: nullptr, : nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>( smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>())) const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr, : nullptr,
sinks ? reinterpret_cast<NV_TYPE *>( sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>())) const_cast<T *>(sinks.get().data<T>()))
: nullptr, : nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()), reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound, quant_max_bound,
quant_min_bound, quant_min_bound,
in_scale, in_scale,
max_seq_len, max_seq_len,
num_chunks, num_chunks,
num_heads, num_heads,
chunk_size, chunk_size,
HEAD_DIM); HEAD_DIM,
} else { token_num,
constexpr int blockx = HEAD_DIM / vec_size; speculate_max_draft_token_num);
constexpr int blocky = (128 + blockx - 1) / blockx;
dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
dim3 blocks_merge(blockx, blocky);
launchWithPdlWhenEnabled(
merge_multi_chunks_v2_kernel<NV_TYPE,
vec_size,
blocky,
HEAD_DIM,
OUT_NV_TYPE,
ENABLE_PREFILL>,
grids_merge,
blocks_merge,
0,
stream,
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
static_cast<float *>(tmp_m->ptr()),
static_cast<float *>(tmp_d->ptr()),
seq_lens_q.data<int>(),
seq_lens_kv.data<int>(),
seq_lens_encoder.data<int>(),
batch_id_per_token.data<int>(),
cu_seqlens_q.data<int>(),
shift_bias ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(shift_bias.get().data<T>()))
: nullptr,
smooth_weight ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(smooth_weight.get().data<T>()))
: nullptr,
sinks ? reinterpret_cast<NV_TYPE *>(
const_cast<T *>(sinks.get().data<T>()))
: nullptr,
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
quant_max_bound,
quant_min_bound,
in_scale,
max_seq_len,
num_chunks,
num_heads,
chunk_size,
HEAD_DIM,
token_num,
speculate_max_draft_token_num);
}
} }
} else { } else {
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2; constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2;
@@ -1525,9 +1440,6 @@ void MultiQueryAppendC8Attention(
int sm_count; int sm_count;
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size); uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
if (!is_decoder) {
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
}
const int num_chunks = div_up(max_seq_len, chunk_size); const int num_chunks = div_up(max_seq_len, chunk_size);
uint32_t attn_mask_len; uint32_t attn_mask_len;
@@ -1643,31 +1555,15 @@ void MultiQueryAppendC8Attention(
phi::SizeOf(paddle::DataType::FLOAT32) * phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(bsz * num_chunks * num_heads)); static_cast<size_t>(bsz * num_chunks * num_heads));
} else { } else {
if (ENABLE_PREFILL) { tmp_workspace = allocator->Allocate(
tmp_workspace = phi::SizeOf(qkv.dtype()) *
allocator->Allocate(phi::SizeOf(qkv.dtype()) * static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
static_cast<size_t>(token_num * num_chunks * tmp_m = allocator->Allocate(
num_heads * HEAD_DIM)); phi::SizeOf(paddle::DataType::FLOAT32) *
tmp_m = allocator->Allocate( static_cast<size_t>(token_num * num_chunks * num_heads));
phi::SizeOf(paddle::DataType::FLOAT32) * tmp_d = allocator->Allocate(
static_cast<size_t>(token_num * num_chunks * num_heads)); phi::SizeOf(paddle::DataType::FLOAT32) *
tmp_d = allocator->Allocate( static_cast<size_t>(token_num * num_chunks * num_heads));
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(token_num * num_chunks * num_heads));
} else {
tmp_workspace = allocator->Allocate(
phi::SizeOf(qkv.dtype()) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads * HEAD_DIM));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads));
tmp_d = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads));
}
} }
launchWithPdlWhenEnabled( launchWithPdlWhenEnabled(
split_kv_kernel, split_kv_kernel,