Support MTP split_kv_attn (#5343)

This commit is contained in:
lzy
2025-12-03 12:40:16 +08:00
committed by GitHub
parent dfeabee123
commit c71a44c7e5
4 changed files with 212 additions and 531 deletions

View File

@@ -2451,7 +2451,6 @@ __global__ void merge_multi_chunks_v2_kernel(
if (bid == -1) {
continue;
}
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
const int seq_len_q = seq_lens_q[bid];
if (seq_len_q == 0) continue;
int seq_len_kv = seq_lens_kv[bid];
@@ -2470,8 +2469,6 @@ __global__ void merge_multi_chunks_v2_kernel(
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
if (num_chunks_this_seq <= 1) {
continue;
} else if (!ENABLE_PREFILL) {
continue;
}
using LoadT = AlignedVector<T, vec_size>;
@@ -2497,32 +2494,14 @@ __global__ void merge_multi_chunks_v2_kernel(
}
#pragma unroll 2
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
uint32_t offset;
if (ENABLE_PREFILL) {
offset = (qid * num_chunks + i) * num_heads + hid;
} else {
offset =
((bid * speculate_max_draft_token_num + local_seq_id) * num_chunks +
i) *
num_heads +
hid;
}
uint32_t offset = (qid * num_chunks + i) * num_heads + hid;
float m_prev = m;
float d_prev = d;
const float m_now = multi_m[offset];
const float d_now = multi_d[offset];
m = max(m_prev, m_now);
if (ENABLE_PREFILL) {
offset =
(qid * num_chunks * num_heads + i * num_heads + hid) * head_dim +
vid * vec_size;
} else {
offset = ((bid * speculate_max_draft_token_num + local_seq_id) *
num_chunks * num_heads +
i * num_heads + hid) *
head_dim +
vid * vec_size;
}
offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim +
vid * vec_size;
Load<T, vec_size>(&multi_out[offset], &load_vec);
const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
const T scale1_T = static_cast<T>(scale1),