mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Others] Maintain the mtp branch temporarily. (#5446)
This commit is contained in:
@@ -2451,6 +2451,7 @@ __global__ void merge_multi_chunks_v2_kernel(
|
||||
if (bid == -1) {
|
||||
continue;
|
||||
}
|
||||
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
|
||||
const int seq_len_q = seq_lens_q[bid];
|
||||
if (seq_len_q == 0) continue;
|
||||
int seq_len_kv = seq_lens_kv[bid];
|
||||
@@ -2494,14 +2495,32 @@ __global__ void merge_multi_chunks_v2_kernel(
|
||||
}
|
||||
#pragma unroll 2
|
||||
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
|
||||
uint32_t offset = (qid * num_chunks + i) * num_heads + hid;
|
||||
uint32_t offset;
|
||||
if (ENABLE_PREFILL) {
|
||||
offset = (qid * num_chunks + i) * num_heads + hid;
|
||||
} else {
|
||||
offset =
|
||||
((bid * speculate_max_draft_token_num + local_seq_id) * num_chunks +
|
||||
i) *
|
||||
num_heads +
|
||||
hid;
|
||||
}
|
||||
float m_prev = m;
|
||||
float d_prev = d;
|
||||
const float m_now = multi_m[offset];
|
||||
const float d_now = multi_d[offset];
|
||||
m = max(m_prev, m_now);
|
||||
offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim +
|
||||
vid * vec_size;
|
||||
if (ENABLE_PREFILL) {
|
||||
offset =
|
||||
(qid * num_chunks * num_heads + i * num_heads + hid) * head_dim +
|
||||
vid * vec_size;
|
||||
} else {
|
||||
offset = ((bid * speculate_max_draft_token_num + local_seq_id) *
|
||||
num_chunks * num_heads +
|
||||
i * num_heads + hid) *
|
||||
head_dim +
|
||||
vid * vec_size;
|
||||
}
|
||||
Load<T, vec_size>(&multi_out[offset], &load_vec);
|
||||
const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
|
||||
const T scale1_T = static_cast<T>(scale1),
|
||||
|
||||
Reference in New Issue
Block a user