mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
supports mtp split_kv_attn (#5343)
This commit is contained in:
@@ -2451,7 +2451,6 @@ __global__ void merge_multi_chunks_v2_kernel(
|
|||||||
if (bid == -1) {
|
if (bid == -1) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
|
|
||||||
const int seq_len_q = seq_lens_q[bid];
|
const int seq_len_q = seq_lens_q[bid];
|
||||||
if (seq_len_q == 0) continue;
|
if (seq_len_q == 0) continue;
|
||||||
int seq_len_kv = seq_lens_kv[bid];
|
int seq_len_kv = seq_lens_kv[bid];
|
||||||
@@ -2470,8 +2469,6 @@ __global__ void merge_multi_chunks_v2_kernel(
|
|||||||
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
|
const int num_chunks_this_seq = div_up(seq_len_kv, chunk_size);
|
||||||
if (num_chunks_this_seq <= 1) {
|
if (num_chunks_this_seq <= 1) {
|
||||||
continue;
|
continue;
|
||||||
} else if (!ENABLE_PREFILL) {
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
using LoadT = AlignedVector<T, vec_size>;
|
using LoadT = AlignedVector<T, vec_size>;
|
||||||
@@ -2497,32 +2494,14 @@ __global__ void merge_multi_chunks_v2_kernel(
|
|||||||
}
|
}
|
||||||
#pragma unroll 2
|
#pragma unroll 2
|
||||||
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
|
for (int i = ty; i < num_chunks_this_seq; i += bdy) {
|
||||||
uint32_t offset;
|
uint32_t offset = (qid * num_chunks + i) * num_heads + hid;
|
||||||
if (ENABLE_PREFILL) {
|
|
||||||
offset = (qid * num_chunks + i) * num_heads + hid;
|
|
||||||
} else {
|
|
||||||
offset =
|
|
||||||
((bid * speculate_max_draft_token_num + local_seq_id) * num_chunks +
|
|
||||||
i) *
|
|
||||||
num_heads +
|
|
||||||
hid;
|
|
||||||
}
|
|
||||||
float m_prev = m;
|
float m_prev = m;
|
||||||
float d_prev = d;
|
float d_prev = d;
|
||||||
const float m_now = multi_m[offset];
|
const float m_now = multi_m[offset];
|
||||||
const float d_now = multi_d[offset];
|
const float d_now = multi_d[offset];
|
||||||
m = max(m_prev, m_now);
|
m = max(m_prev, m_now);
|
||||||
if (ENABLE_PREFILL) {
|
offset = (qid * num_chunks * num_heads + i * num_heads + hid) * head_dim +
|
||||||
offset =
|
vid * vec_size;
|
||||||
(qid * num_chunks * num_heads + i * num_heads + hid) * head_dim +
|
|
||||||
vid * vec_size;
|
|
||||||
} else {
|
|
||||||
offset = ((bid * speculate_max_draft_token_num + local_seq_id) *
|
|
||||||
num_chunks * num_heads +
|
|
||||||
i * num_heads + hid) *
|
|
||||||
head_dim +
|
|
||||||
vid * vec_size;
|
|
||||||
}
|
|
||||||
Load<T, vec_size>(&multi_out[offset], &load_vec);
|
Load<T, vec_size>(&multi_out[offset], &load_vec);
|
||||||
const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
|
const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m);
|
||||||
const T scale1_T = static_cast<T>(scale1),
|
const T scale1_T = static_cast<T>(scale1),
|
||||||
|
|||||||
@@ -134,17 +134,9 @@ __global__ void multi_query_append_attention_kernel(
|
|||||||
T *o_base_ptr_T = nullptr;
|
T *o_base_ptr_T = nullptr;
|
||||||
OutT *o_base_ptr_int8 = nullptr;
|
OutT *o_base_ptr_int8 = nullptr;
|
||||||
if constexpr (partition_kv) {
|
if constexpr (partition_kv) {
|
||||||
if (ENABLE_PREFILL) {
|
o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
|
||||||
o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
|
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
|
||||||
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
|
tid % 8 * num_elems_per_128b<T>();
|
||||||
tid % 8 * num_elems_per_128b<T>();
|
|
||||||
} else {
|
|
||||||
o_base_ptr_T =
|
|
||||||
tmp_workspace +
|
|
||||||
batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
|
|
||||||
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
|
|
||||||
tid % 8 * num_elems_per_128b<T>();
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
o_base_ptr_int8 = out + o_offset;
|
o_base_ptr_int8 = out + o_offset;
|
||||||
}
|
}
|
||||||
@@ -394,18 +386,8 @@ __global__ void multi_query_append_attention_kernel(
|
|||||||
const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE;
|
const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE;
|
||||||
const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
|
const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
|
||||||
if (qo_idx - q_start_seq_id < q_len) {
|
if (qo_idx - q_start_seq_id < q_len) {
|
||||||
uint32_t offset;
|
uint32_t offset =
|
||||||
if (ENABLE_PREFILL) {
|
(qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
|
||||||
offset =
|
|
||||||
(qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
|
|
||||||
} else {
|
|
||||||
offset = ((batch_id * speculate_max_draft_token_num +
|
|
||||||
qo_idx_now / GROUP_SIZE) *
|
|
||||||
num_chunks +
|
|
||||||
chunk_idx) *
|
|
||||||
q_num_heads +
|
|
||||||
qo_head_idx;
|
|
||||||
}
|
|
||||||
tmp_m[offset] = m_frag[fx][j];
|
tmp_m[offset] = m_frag[fx][j];
|
||||||
tmp_d[offset] = d_frag[fx][j];
|
tmp_d[offset] = d_frag[fx][j];
|
||||||
}
|
}
|
||||||
@@ -542,11 +524,9 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
|||||||
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
|
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
|
||||||
tid % 8 * num_elems_per_128b<T>();
|
tid % 8 * num_elems_per_128b<T>();
|
||||||
} else {
|
} else {
|
||||||
o_base_ptr_T =
|
o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
|
||||||
tmp_workspace +
|
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
|
||||||
batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
|
tid % 8 * num_elems_per_128b<T>();
|
||||||
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
|
|
||||||
tid % 8 * num_elems_per_128b<T>();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const int *mask_offset_this_seq =
|
const int *mask_offset_this_seq =
|
||||||
@@ -814,12 +794,8 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
|
|||||||
offset = (batch_id * num_chunks + chunk_idx) * q_num_heads +
|
offset = (batch_id * num_chunks + chunk_idx) * q_num_heads +
|
||||||
qo_head_idx;
|
qo_head_idx;
|
||||||
} else {
|
} else {
|
||||||
offset = ((batch_id * speculate_max_draft_token_num +
|
offset =
|
||||||
qo_idx_now / GROUP_SIZE) *
|
(qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
|
||||||
num_chunks +
|
|
||||||
chunk_idx) *
|
|
||||||
q_num_heads +
|
|
||||||
qo_head_idx;
|
|
||||||
}
|
}
|
||||||
tmp_m[offset] = m_frag[fx][j];
|
tmp_m[offset] = m_frag[fx][j];
|
||||||
tmp_d[offset] = d_frag[fx][j];
|
tmp_d[offset] = d_frag[fx][j];
|
||||||
@@ -918,10 +894,7 @@ void MultiQueryAppendAttention(
|
|||||||
int sm_count;
|
int sm_count;
|
||||||
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
|
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
|
||||||
|
|
||||||
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
|
uint32_t chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
||||||
if (!is_decoder) {
|
|
||||||
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
|
||||||
}
|
|
||||||
const int num_chunks = div_up(max_dec_len, chunk_size);
|
const int num_chunks = div_up(max_dec_len, chunk_size);
|
||||||
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
||||||
dim3 blocks(32, num_warps);
|
dim3 blocks(32, num_warps);
|
||||||
@@ -1053,95 +1026,51 @@ void MultiQueryAppendAttention(
|
|||||||
sliding_window);
|
sliding_window);
|
||||||
// merge
|
// merge
|
||||||
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
|
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
|
||||||
if (is_decoder) {
|
constexpr int blockx = HEAD_DIM / vec_size;
|
||||||
constexpr int blockx = HEAD_DIM / vec_size;
|
constexpr int blocky = (128 + blockx - 1) / blockx;
|
||||||
constexpr int blocky = (128 + blockx - 1) / blockx;
|
dim3 grids_merge(min(sm_count * 4, token_num),
|
||||||
dim3 grids_merge(bsz, num_heads);
|
num_heads); // 128k is too large
|
||||||
dim3 blocks_merge(blockx, blocky);
|
dim3 blocks_merge(blockx, blocky);
|
||||||
auto *kernelFn = merge_multi_chunks_decoder_kernel<NV_TYPE,
|
auto *kernelFn = merge_multi_chunks_v2_kernel<NV_TYPE,
|
||||||
vec_size,
|
vec_size,
|
||||||
blocky,
|
blocky,
|
||||||
HEAD_DIM,
|
HEAD_DIM,
|
||||||
OUT_NV_TYPE,
|
OUT_NV_TYPE,
|
||||||
ENABLE_PREFILL>;
|
ENABLE_PREFILL>;
|
||||||
launchWithPdlWhenEnabled(
|
launchWithPdlWhenEnabled(
|
||||||
kernelFn,
|
kernelFn,
|
||||||
grids_merge,
|
grids_merge,
|
||||||
blocks_merge,
|
blocks_merge,
|
||||||
0,
|
0,
|
||||||
stream,
|
stream,
|
||||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||||
static_cast<float *>(tmp_m->ptr()),
|
static_cast<float *>(tmp_m->ptr()),
|
||||||
static_cast<float *>(tmp_d->ptr()),
|
static_cast<float *>(tmp_d->ptr()),
|
||||||
seq_lens_q.data<int>(),
|
seq_lens_q.data<int>(),
|
||||||
seq_lens_kv.data<int>(),
|
seq_lens_kv.data<int>(),
|
||||||
seq_lens_encoder.data<int>(),
|
seq_lens_encoder.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
batch_id_per_token.data<int>(),
|
||||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
cu_seqlens_q.data<int>(),
|
||||||
const_cast<T *>(shift_bias.get().data<T>()))
|
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||||
: nullptr,
|
const_cast<T *>(shift_bias.get().data<T>()))
|
||||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
: nullptr,
|
||||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||||
: nullptr,
|
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
: nullptr,
|
||||||
const_cast<T *>(sinks.get().data<T>()))
|
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||||
: nullptr,
|
const_cast<T *>(sinks.get().data<T>()))
|
||||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
: nullptr,
|
||||||
quant_max_bound,
|
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||||
quant_min_bound,
|
quant_max_bound,
|
||||||
in_scale,
|
quant_min_bound,
|
||||||
max_seq_len,
|
in_scale,
|
||||||
num_chunks,
|
max_seq_len,
|
||||||
num_heads,
|
num_chunks,
|
||||||
chunk_size,
|
num_heads,
|
||||||
HEAD_DIM);
|
chunk_size,
|
||||||
} else {
|
HEAD_DIM,
|
||||||
constexpr int blockx = HEAD_DIM / vec_size;
|
token_num,
|
||||||
constexpr int blocky = (128 + blockx - 1) / blockx;
|
speculate_max_draft_token_num);
|
||||||
dim3 grids_merge(min(sm_count * 4, token_num),
|
|
||||||
num_heads); // 128k is too large
|
|
||||||
dim3 blocks_merge(blockx, blocky);
|
|
||||||
auto *kernelFn = merge_multi_chunks_v2_kernel<NV_TYPE,
|
|
||||||
vec_size,
|
|
||||||
blocky,
|
|
||||||
HEAD_DIM,
|
|
||||||
OUT_NV_TYPE,
|
|
||||||
ENABLE_PREFILL>;
|
|
||||||
launchWithPdlWhenEnabled(
|
|
||||||
kernelFn,
|
|
||||||
grids_merge,
|
|
||||||
blocks_merge,
|
|
||||||
0,
|
|
||||||
stream,
|
|
||||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
|
||||||
static_cast<float *>(tmp_m->ptr()),
|
|
||||||
static_cast<float *>(tmp_d->ptr()),
|
|
||||||
seq_lens_q.data<int>(),
|
|
||||||
seq_lens_kv.data<int>(),
|
|
||||||
seq_lens_encoder.data<int>(),
|
|
||||||
batch_id_per_token.data<int>(),
|
|
||||||
cu_seqlens_q.data<int>(),
|
|
||||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
|
||||||
const_cast<T *>(shift_bias.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
|
||||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
|
||||||
const_cast<T *>(sinks.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
|
||||||
quant_max_bound,
|
|
||||||
quant_min_bound,
|
|
||||||
in_scale,
|
|
||||||
max_seq_len,
|
|
||||||
num_chunks,
|
|
||||||
num_heads,
|
|
||||||
chunk_size,
|
|
||||||
HEAD_DIM,
|
|
||||||
token_num,
|
|
||||||
speculate_max_draft_token_num);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV;
|
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV;
|
||||||
@@ -1173,9 +1102,6 @@ void MultiQueryAppendAttention(
|
|||||||
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
|
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
|
||||||
|
|
||||||
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
|
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
|
||||||
if (!is_decoder) {
|
|
||||||
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t attn_mask_len;
|
uint32_t attn_mask_len;
|
||||||
if (attn_mask) {
|
if (attn_mask) {
|
||||||
@@ -1263,31 +1189,15 @@ void MultiQueryAppendAttention(
|
|||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
phi::SizeOf(paddle::DataType::FLOAT32) *
|
||||||
static_cast<size_t>(bsz * num_chunks * num_heads));
|
static_cast<size_t>(bsz * num_chunks * num_heads));
|
||||||
} else {
|
} else {
|
||||||
if (ENABLE_PREFILL) {
|
tmp_workspace = allocator->Allocate(
|
||||||
tmp_workspace =
|
phi::SizeOf(qkv.dtype()) *
|
||||||
allocator->Allocate(phi::SizeOf(qkv.dtype()) *
|
static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
|
||||||
static_cast<size_t>(token_num * num_chunks *
|
tmp_m = allocator->Allocate(
|
||||||
num_heads * HEAD_DIM));
|
phi::SizeOf(paddle::DataType::FLOAT32) *
|
||||||
tmp_m = allocator->Allocate(
|
static_cast<size_t>(token_num * num_chunks * num_heads));
|
||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
tmp_d = allocator->Allocate(
|
||||||
static_cast<size_t>(token_num * num_chunks * num_heads));
|
phi::SizeOf(paddle::DataType::FLOAT32) *
|
||||||
tmp_d = allocator->Allocate(
|
static_cast<size_t>(token_num * num_chunks * num_heads));
|
||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
|
||||||
static_cast<size_t>(token_num * num_chunks * num_heads));
|
|
||||||
} else {
|
|
||||||
tmp_workspace = allocator->Allocate(
|
|
||||||
phi::SizeOf(qkv.dtype()) *
|
|
||||||
static_cast<size_t>(speculate_max_draft_token_num * bsz *
|
|
||||||
num_chunks * num_heads * HEAD_DIM));
|
|
||||||
tmp_m = allocator->Allocate(
|
|
||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
|
||||||
static_cast<size_t>(speculate_max_draft_token_num * bsz *
|
|
||||||
num_chunks * num_heads));
|
|
||||||
tmp_d = allocator->Allocate(
|
|
||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
|
||||||
static_cast<size_t>(speculate_max_draft_token_num * bsz *
|
|
||||||
num_chunks * num_heads));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
launchWithPdlWhenEnabled(
|
launchWithPdlWhenEnabled(
|
||||||
split_kv_kernel,
|
split_kv_kernel,
|
||||||
|
|||||||
@@ -169,17 +169,9 @@ __global__ void multi_query_append_attention_c4_kernel(
|
|||||||
T *o_base_ptr_T = nullptr;
|
T *o_base_ptr_T = nullptr;
|
||||||
OutT *o_base_ptr_int8 = nullptr;
|
OutT *o_base_ptr_int8 = nullptr;
|
||||||
if constexpr (partition_kv) {
|
if constexpr (partition_kv) {
|
||||||
if (ENABLE_PREFILL) {
|
o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
|
||||||
o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
|
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
|
||||||
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
|
tid % 8 * num_elems_per_128b<T>();
|
||||||
tid % 8 * num_elems_per_128b<T>();
|
|
||||||
} else {
|
|
||||||
o_base_ptr_T =
|
|
||||||
tmp_workspace +
|
|
||||||
batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
|
|
||||||
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
|
|
||||||
tid % 8 * num_elems_per_128b<T>();
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
o_base_ptr_int8 = out + o_offset;
|
o_base_ptr_int8 = out + o_offset;
|
||||||
}
|
}
|
||||||
@@ -485,18 +477,8 @@ __global__ void multi_query_append_attention_c4_kernel(
|
|||||||
const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE;
|
const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE;
|
||||||
const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
|
const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
|
||||||
if (qo_idx - q_start_seq_id < q_len) {
|
if (qo_idx - q_start_seq_id < q_len) {
|
||||||
uint32_t offset;
|
uint32_t offset =
|
||||||
if (ENABLE_PREFILL) {
|
(qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
|
||||||
offset =
|
|
||||||
(qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
|
|
||||||
} else {
|
|
||||||
offset = ((batch_id * speculate_max_draft_token_num +
|
|
||||||
qo_idx_now / GROUP_SIZE) *
|
|
||||||
num_chunks +
|
|
||||||
chunk_idx) *
|
|
||||||
q_num_heads +
|
|
||||||
qo_head_idx;
|
|
||||||
}
|
|
||||||
tmp_m[offset] = m_frag[fx][j];
|
tmp_m[offset] = m_frag[fx][j];
|
||||||
tmp_d[offset] = d_frag[fx][j];
|
tmp_d[offset] = d_frag[fx][j];
|
||||||
}
|
}
|
||||||
@@ -669,11 +651,9 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
|||||||
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
|
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
|
||||||
tid % 8 * num_elems_per_128b<T>();
|
tid % 8 * num_elems_per_128b<T>();
|
||||||
} else {
|
} else {
|
||||||
o_base_ptr_T =
|
o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
|
||||||
tmp_workspace +
|
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
|
||||||
batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
|
tid % 8 * num_elems_per_128b<T>();
|
||||||
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
|
|
||||||
tid % 8 * num_elems_per_128b<T>();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const int *mask_offset_this_seq =
|
const int *mask_offset_this_seq =
|
||||||
@@ -989,12 +969,8 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel(
|
|||||||
offset = (batch_id * num_chunks + chunk_idx) * q_num_heads +
|
offset = (batch_id * num_chunks + chunk_idx) * q_num_heads +
|
||||||
qo_head_idx;
|
qo_head_idx;
|
||||||
} else {
|
} else {
|
||||||
offset = ((batch_id * speculate_max_draft_token_num +
|
offset =
|
||||||
qo_idx_now / GROUP_SIZE) *
|
(qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
|
||||||
num_chunks +
|
|
||||||
chunk_idx) *
|
|
||||||
q_num_heads +
|
|
||||||
qo_head_idx;
|
|
||||||
}
|
}
|
||||||
tmp_m[offset] = m_frag[fx][j];
|
tmp_m[offset] = m_frag[fx][j];
|
||||||
tmp_d[offset] = d_frag[fx][j];
|
tmp_d[offset] = d_frag[fx][j];
|
||||||
@@ -1108,10 +1084,7 @@ void MultiQueryAppendC4Attention(
|
|||||||
const float ratio = static_cast<float>(num_blocks_need) /
|
const float ratio = static_cast<float>(num_blocks_need) /
|
||||||
static_cast<float>(num_blocks_per_wave);
|
static_cast<float>(num_blocks_per_wave);
|
||||||
|
|
||||||
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
|
uint32_t chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
||||||
if (!is_decoder) {
|
|
||||||
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
|
||||||
}
|
|
||||||
const int num_chunks = div_up(max_dec_len, chunk_size);
|
const int num_chunks = div_up(max_dec_len, chunk_size);
|
||||||
|
|
||||||
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
||||||
@@ -1188,30 +1161,15 @@ void MultiQueryAppendC4Attention(
|
|||||||
sliding_window);
|
sliding_window);
|
||||||
} else {
|
} else {
|
||||||
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
||||||
if (ENABLE_PREFILL) {
|
tmp_workspace = allocator->Allocate(
|
||||||
tmp_workspace = allocator->Allocate(
|
phi::SizeOf(qkv.dtype()) *
|
||||||
phi::SizeOf(qkv.dtype()) *
|
static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
|
||||||
static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
|
tmp_m = allocator->Allocate(
|
||||||
tmp_m = allocator->Allocate(
|
phi::SizeOf(paddle::DataType::FLOAT32) *
|
||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
static_cast<size_t>(token_num * num_chunks * num_heads));
|
||||||
static_cast<size_t>(token_num * num_chunks * num_heads));
|
tmp_d = allocator->Allocate(
|
||||||
tmp_d = allocator->Allocate(
|
phi::SizeOf(paddle::DataType::FLOAT32) *
|
||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
static_cast<size_t>(token_num * num_chunks * num_heads));
|
||||||
static_cast<size_t>(token_num * num_chunks * num_heads));
|
|
||||||
} else {
|
|
||||||
tmp_workspace = allocator->Allocate(
|
|
||||||
phi::SizeOf(qkv.dtype()) *
|
|
||||||
static_cast<size_t>(speculate_max_draft_token_num * bsz *
|
|
||||||
num_chunks * num_heads * HEAD_DIM));
|
|
||||||
tmp_m = allocator->Allocate(
|
|
||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
|
||||||
static_cast<size_t>(speculate_max_draft_token_num * bsz *
|
|
||||||
num_chunks * num_heads));
|
|
||||||
tmp_d = allocator->Allocate(
|
|
||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
|
||||||
static_cast<size_t>(speculate_max_draft_token_num * bsz *
|
|
||||||
num_chunks * num_heads));
|
|
||||||
}
|
|
||||||
launchWithPdlWhenEnabled(
|
launchWithPdlWhenEnabled(
|
||||||
split_kv_kernel,
|
split_kv_kernel,
|
||||||
grids,
|
grids,
|
||||||
@@ -1262,92 +1220,49 @@ void MultiQueryAppendC4Attention(
|
|||||||
sliding_window);
|
sliding_window);
|
||||||
// merge
|
// merge
|
||||||
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
|
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
|
||||||
if (is_decoder) {
|
constexpr int blockx = HEAD_DIM / vec_size;
|
||||||
constexpr int blockx = HEAD_DIM / vec_size;
|
constexpr int blocky = (128 + blockx - 1) / blockx;
|
||||||
constexpr int blocky = (128 + blockx - 1) / blockx;
|
dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
|
||||||
dim3 grids_merge(bsz, num_heads);
|
dim3 blocks_merge(blockx, blocky);
|
||||||
dim3 blocks_merge(blockx, blocky);
|
launchWithPdlWhenEnabled(
|
||||||
launchWithPdlWhenEnabled(
|
merge_multi_chunks_v2_kernel<NV_TYPE,
|
||||||
merge_multi_chunks_decoder_kernel<NV_TYPE,
|
vec_size,
|
||||||
vec_size,
|
blocky,
|
||||||
blocky,
|
HEAD_DIM,
|
||||||
HEAD_DIM,
|
OUT_NV_TYPE,
|
||||||
OUT_NV_TYPE,
|
ENABLE_PREFILL>,
|
||||||
ENABLE_PREFILL>,
|
grids_merge,
|
||||||
grids_merge,
|
blocks_merge,
|
||||||
blocks_merge,
|
0,
|
||||||
0,
|
stream,
|
||||||
stream,
|
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
static_cast<float *>(tmp_m->ptr()),
|
||||||
static_cast<float *>(tmp_m->ptr()),
|
static_cast<float *>(tmp_d->ptr()),
|
||||||
static_cast<float *>(tmp_d->ptr()),
|
seq_lens_q.data<int>(),
|
||||||
seq_lens_q.data<int>(),
|
seq_lens_kv.data<int>(),
|
||||||
seq_lens_kv.data<int>(),
|
seq_lens_encoder.data<int>(),
|
||||||
seq_lens_encoder.data<int>(),
|
batch_id_per_token.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||||
const_cast<T *>(shift_bias.get().data<T>()))
|
const_cast<T *>(shift_bias.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||||
const_cast<T *>(sinks.get().data<T>()))
|
const_cast<T *>(sinks.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||||
quant_max_bound,
|
quant_max_bound,
|
||||||
quant_min_bound,
|
quant_min_bound,
|
||||||
in_scale,
|
in_scale,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
num_chunks,
|
num_chunks,
|
||||||
num_heads,
|
num_heads,
|
||||||
chunk_size,
|
chunk_size,
|
||||||
HEAD_DIM);
|
HEAD_DIM,
|
||||||
} else {
|
token_num,
|
||||||
constexpr int blockx = HEAD_DIM / vec_size;
|
speculate_max_draft_token_num);
|
||||||
constexpr int blocky = (128 + blockx - 1) / blockx;
|
|
||||||
dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
|
|
||||||
dim3 blocks_merge(blockx, blocky);
|
|
||||||
launchWithPdlWhenEnabled(
|
|
||||||
merge_multi_chunks_v2_kernel<NV_TYPE,
|
|
||||||
vec_size,
|
|
||||||
blocky,
|
|
||||||
HEAD_DIM,
|
|
||||||
OUT_NV_TYPE,
|
|
||||||
ENABLE_PREFILL>,
|
|
||||||
grids_merge,
|
|
||||||
blocks_merge,
|
|
||||||
0,
|
|
||||||
stream,
|
|
||||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
|
||||||
static_cast<float *>(tmp_m->ptr()),
|
|
||||||
static_cast<float *>(tmp_d->ptr()),
|
|
||||||
seq_lens_q.data<int>(),
|
|
||||||
seq_lens_kv.data<int>(),
|
|
||||||
seq_lens_encoder.data<int>(),
|
|
||||||
batch_id_per_token.data<int>(),
|
|
||||||
cu_seqlens_q.data<int>(),
|
|
||||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
|
||||||
const_cast<T *>(shift_bias.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
|
||||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
|
||||||
const_cast<T *>(sinks.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
|
||||||
quant_max_bound,
|
|
||||||
quant_min_bound,
|
|
||||||
in_scale,
|
|
||||||
max_seq_len,
|
|
||||||
num_chunks,
|
|
||||||
num_heads,
|
|
||||||
chunk_size,
|
|
||||||
HEAD_DIM,
|
|
||||||
token_num,
|
|
||||||
speculate_max_draft_token_num);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 4;
|
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 4;
|
||||||
@@ -1390,9 +1305,6 @@ void MultiQueryAppendC4Attention(
|
|||||||
static_cast<float>(num_blocks_per_wave);
|
static_cast<float>(num_blocks_per_wave);
|
||||||
|
|
||||||
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
|
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
|
||||||
if (!is_decoder) {
|
|
||||||
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
const int num_chunks = div_up(max_seq_len, chunk_size);
|
const int num_chunks = div_up(max_seq_len, chunk_size);
|
||||||
uint32_t attn_mask_len;
|
uint32_t attn_mask_len;
|
||||||
@@ -1490,31 +1402,15 @@ void MultiQueryAppendC4Attention(
|
|||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
phi::SizeOf(paddle::DataType::FLOAT32) *
|
||||||
static_cast<size_t>(bsz * num_chunks * num_heads));
|
static_cast<size_t>(bsz * num_chunks * num_heads));
|
||||||
} else {
|
} else {
|
||||||
if (ENABLE_PREFILL) {
|
tmp_workspace = allocator->Allocate(
|
||||||
tmp_workspace =
|
phi::SizeOf(qkv.dtype()) *
|
||||||
allocator->Allocate(phi::SizeOf(qkv.dtype()) *
|
static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
|
||||||
static_cast<size_t>(token_num * num_chunks *
|
tmp_m = allocator->Allocate(
|
||||||
num_heads * HEAD_DIM));
|
phi::SizeOf(paddle::DataType::FLOAT32) *
|
||||||
tmp_m = allocator->Allocate(
|
static_cast<size_t>(token_num * num_chunks * num_heads));
|
||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
tmp_d = allocator->Allocate(
|
||||||
static_cast<size_t>(token_num * num_chunks * num_heads));
|
phi::SizeOf(paddle::DataType::FLOAT32) *
|
||||||
tmp_d = allocator->Allocate(
|
static_cast<size_t>(token_num * num_chunks * num_heads));
|
||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
|
||||||
static_cast<size_t>(token_num * num_chunks * num_heads));
|
|
||||||
} else {
|
|
||||||
tmp_workspace = allocator->Allocate(
|
|
||||||
phi::SizeOf(qkv.dtype()) *
|
|
||||||
static_cast<size_t>(speculate_max_draft_token_num * bsz *
|
|
||||||
num_chunks * num_heads * HEAD_DIM));
|
|
||||||
tmp_m = allocator->Allocate(
|
|
||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
|
||||||
static_cast<size_t>(speculate_max_draft_token_num * bsz *
|
|
||||||
num_chunks * num_heads));
|
|
||||||
tmp_d = allocator->Allocate(
|
|
||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
|
||||||
static_cast<size_t>(speculate_max_draft_token_num * bsz *
|
|
||||||
num_chunks * num_heads));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
launchWithPdlWhenEnabled(
|
launchWithPdlWhenEnabled(
|
||||||
split_kv_kernel,
|
split_kv_kernel,
|
||||||
|
|||||||
@@ -178,17 +178,9 @@ __global__ void multi_query_append_attention_c8_kernel(
|
|||||||
T *o_base_ptr_T = nullptr;
|
T *o_base_ptr_T = nullptr;
|
||||||
OutT *o_base_ptr_int8 = nullptr;
|
OutT *o_base_ptr_int8 = nullptr;
|
||||||
if constexpr (partition_kv) {
|
if constexpr (partition_kv) {
|
||||||
if (ENABLE_PREFILL) {
|
o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
|
||||||
o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
|
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
|
||||||
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
|
tid % 8 * num_elems_per_128b<T>();
|
||||||
tid % 8 * num_elems_per_128b<T>();
|
|
||||||
} else {
|
|
||||||
o_base_ptr_T =
|
|
||||||
tmp_workspace +
|
|
||||||
batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
|
|
||||||
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
|
|
||||||
tid % 8 * num_elems_per_128b<T>();
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
o_base_ptr_int8 = out + o_offset;
|
o_base_ptr_int8 = out + o_offset;
|
||||||
}
|
}
|
||||||
@@ -532,18 +524,8 @@ __global__ void multi_query_append_attention_c8_kernel(
|
|||||||
const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE;
|
const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE;
|
||||||
const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
|
const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;
|
||||||
if (qo_idx - q_start_seq_id < q_len) {
|
if (qo_idx - q_start_seq_id < q_len) {
|
||||||
uint32_t offset;
|
uint32_t offset =
|
||||||
if (ENABLE_PREFILL) {
|
(qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
|
||||||
offset =
|
|
||||||
(qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
|
|
||||||
} else {
|
|
||||||
offset = ((batch_id * speculate_max_draft_token_num +
|
|
||||||
qo_idx_now / GROUP_SIZE) *
|
|
||||||
num_chunks +
|
|
||||||
chunk_idx) *
|
|
||||||
q_num_heads +
|
|
||||||
qo_head_idx;
|
|
||||||
}
|
|
||||||
tmp_m[offset] = m_frag[fx][j];
|
tmp_m[offset] = m_frag[fx][j];
|
||||||
tmp_d[offset] = d_frag[fx][j];
|
tmp_d[offset] = d_frag[fx][j];
|
||||||
}
|
}
|
||||||
@@ -720,11 +702,9 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
|||||||
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
|
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
|
||||||
tid % 8 * num_elems_per_128b<T>();
|
tid % 8 * num_elems_per_128b<T>();
|
||||||
} else {
|
} else {
|
||||||
o_base_ptr_T =
|
o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride +
|
||||||
tmp_workspace +
|
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
|
||||||
batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
|
tid % 8 * num_elems_per_128b<T>();
|
||||||
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
|
|
||||||
tid % 8 * num_elems_per_128b<T>();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const int *mask_offset_this_seq =
|
const int *mask_offset_this_seq =
|
||||||
@@ -1083,12 +1063,8 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel(
|
|||||||
offset = (batch_id * num_chunks + chunk_idx) * q_num_heads +
|
offset = (batch_id * num_chunks + chunk_idx) * q_num_heads +
|
||||||
qo_head_idx;
|
qo_head_idx;
|
||||||
} else {
|
} else {
|
||||||
offset = ((batch_id * speculate_max_draft_token_num +
|
offset =
|
||||||
qo_idx_now / GROUP_SIZE) *
|
(qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx;
|
||||||
num_chunks +
|
|
||||||
chunk_idx) *
|
|
||||||
q_num_heads +
|
|
||||||
qo_head_idx;
|
|
||||||
}
|
}
|
||||||
tmp_m[offset] = m_frag[fx][j];
|
tmp_m[offset] = m_frag[fx][j];
|
||||||
tmp_d[offset] = d_frag[fx][j];
|
tmp_d[offset] = d_frag[fx][j];
|
||||||
@@ -1218,10 +1194,7 @@ void MultiQueryAppendC8Attention(
|
|||||||
const int dev_id = 0;
|
const int dev_id = 0;
|
||||||
int sm_count;
|
int sm_count;
|
||||||
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
|
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
|
||||||
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
|
uint32_t chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
||||||
if (!is_decoder) {
|
|
||||||
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
|
||||||
}
|
|
||||||
const int num_chunks = div_up(max_dec_len, chunk_size);
|
const int num_chunks = div_up(max_dec_len, chunk_size);
|
||||||
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads);
|
||||||
dim3 blocks(32, num_warps);
|
dim3 blocks(32, num_warps);
|
||||||
@@ -1315,30 +1288,15 @@ void MultiQueryAppendC8Attention(
|
|||||||
sliding_window);
|
sliding_window);
|
||||||
} else {
|
} else {
|
||||||
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
|
||||||
if (ENABLE_PREFILL) {
|
tmp_workspace = allocator->Allocate(
|
||||||
tmp_workspace = allocator->Allocate(
|
phi::SizeOf(qkv.dtype()) *
|
||||||
phi::SizeOf(qkv.dtype()) *
|
static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
|
||||||
static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
|
tmp_m = allocator->Allocate(
|
||||||
tmp_m = allocator->Allocate(
|
phi::SizeOf(paddle::DataType::FLOAT32) *
|
||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
static_cast<size_t>(token_num * num_chunks * num_heads));
|
||||||
static_cast<size_t>(token_num * num_chunks * num_heads));
|
tmp_d = allocator->Allocate(
|
||||||
tmp_d = allocator->Allocate(
|
phi::SizeOf(paddle::DataType::FLOAT32) *
|
||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
static_cast<size_t>(token_num * num_chunks * num_heads));
|
||||||
static_cast<size_t>(token_num * num_chunks * num_heads));
|
|
||||||
} else {
|
|
||||||
tmp_workspace = allocator->Allocate(
|
|
||||||
phi::SizeOf(qkv.dtype()) *
|
|
||||||
static_cast<size_t>(speculate_max_draft_token_num * bsz *
|
|
||||||
num_chunks * num_heads * HEAD_DIM));
|
|
||||||
tmp_m = allocator->Allocate(
|
|
||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
|
||||||
static_cast<size_t>(speculate_max_draft_token_num * bsz *
|
|
||||||
num_chunks * num_heads));
|
|
||||||
tmp_d = allocator->Allocate(
|
|
||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
|
||||||
static_cast<size_t>(speculate_max_draft_token_num * bsz *
|
|
||||||
num_chunks * num_heads));
|
|
||||||
}
|
|
||||||
launchWithPdlWhenEnabled(
|
launchWithPdlWhenEnabled(
|
||||||
split_kv_kernel,
|
split_kv_kernel,
|
||||||
grids,
|
grids,
|
||||||
@@ -1383,92 +1341,49 @@ void MultiQueryAppendC8Attention(
|
|||||||
sliding_window);
|
sliding_window);
|
||||||
// merge
|
// merge
|
||||||
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
|
constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
|
||||||
if (is_decoder) {
|
constexpr int blockx = HEAD_DIM / vec_size;
|
||||||
constexpr int blockx = HEAD_DIM / vec_size;
|
constexpr int blocky = (128 + blockx - 1) / blockx;
|
||||||
constexpr int blocky = (128 + blockx - 1) / blockx;
|
dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
|
||||||
dim3 grids_merge(bsz, num_heads);
|
dim3 blocks_merge(blockx, blocky);
|
||||||
dim3 blocks_merge(blockx, blocky);
|
launchWithPdlWhenEnabled(
|
||||||
launchWithPdlWhenEnabled(
|
merge_multi_chunks_v2_kernel<NV_TYPE,
|
||||||
merge_multi_chunks_decoder_kernel<NV_TYPE,
|
vec_size,
|
||||||
vec_size,
|
blocky,
|
||||||
blocky,
|
HEAD_DIM,
|
||||||
HEAD_DIM,
|
OUT_NV_TYPE,
|
||||||
OUT_NV_TYPE,
|
ENABLE_PREFILL>,
|
||||||
ENABLE_PREFILL>,
|
grids_merge,
|
||||||
grids_merge,
|
blocks_merge,
|
||||||
blocks_merge,
|
0,
|
||||||
0,
|
stream,
|
||||||
stream,
|
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
||||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
static_cast<float *>(tmp_m->ptr()),
|
||||||
static_cast<float *>(tmp_m->ptr()),
|
static_cast<float *>(tmp_d->ptr()),
|
||||||
static_cast<float *>(tmp_d->ptr()),
|
seq_lens_q.data<int>(),
|
||||||
seq_lens_q.data<int>(),
|
seq_lens_kv.data<int>(),
|
||||||
seq_lens_kv.data<int>(),
|
seq_lens_encoder.data<int>(),
|
||||||
seq_lens_encoder.data<int>(),
|
batch_id_per_token.data<int>(),
|
||||||
cu_seqlens_q.data<int>(),
|
cu_seqlens_q.data<int>(),
|
||||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
||||||
const_cast<T *>(shift_bias.get().data<T>()))
|
const_cast<T *>(shift_bias.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
||||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
const_cast<T *>(smooth_weight.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
sinks ? reinterpret_cast<NV_TYPE *>(
|
||||||
const_cast<T *>(sinks.get().data<T>()))
|
const_cast<T *>(sinks.get().data<T>()))
|
||||||
: nullptr,
|
: nullptr,
|
||||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
||||||
quant_max_bound,
|
quant_max_bound,
|
||||||
quant_min_bound,
|
quant_min_bound,
|
||||||
in_scale,
|
in_scale,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
num_chunks,
|
num_chunks,
|
||||||
num_heads,
|
num_heads,
|
||||||
chunk_size,
|
chunk_size,
|
||||||
HEAD_DIM);
|
HEAD_DIM,
|
||||||
} else {
|
token_num,
|
||||||
constexpr int blockx = HEAD_DIM / vec_size;
|
speculate_max_draft_token_num);
|
||||||
constexpr int blocky = (128 + blockx - 1) / blockx;
|
|
||||||
dim3 grids_merge(min(sm_count * 4, token_num), num_heads);
|
|
||||||
dim3 blocks_merge(blockx, blocky);
|
|
||||||
launchWithPdlWhenEnabled(
|
|
||||||
merge_multi_chunks_v2_kernel<NV_TYPE,
|
|
||||||
vec_size,
|
|
||||||
blocky,
|
|
||||||
HEAD_DIM,
|
|
||||||
OUT_NV_TYPE,
|
|
||||||
ENABLE_PREFILL>,
|
|
||||||
grids_merge,
|
|
||||||
blocks_merge,
|
|
||||||
0,
|
|
||||||
stream,
|
|
||||||
reinterpret_cast<NV_TYPE *>(tmp_workspace->ptr()),
|
|
||||||
static_cast<float *>(tmp_m->ptr()),
|
|
||||||
static_cast<float *>(tmp_d->ptr()),
|
|
||||||
seq_lens_q.data<int>(),
|
|
||||||
seq_lens_kv.data<int>(),
|
|
||||||
seq_lens_encoder.data<int>(),
|
|
||||||
batch_id_per_token.data<int>(),
|
|
||||||
cu_seqlens_q.data<int>(),
|
|
||||||
shift_bias ? reinterpret_cast<NV_TYPE *>(
|
|
||||||
const_cast<T *>(shift_bias.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
smooth_weight ? reinterpret_cast<NV_TYPE *>(
|
|
||||||
const_cast<T *>(smooth_weight.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
sinks ? reinterpret_cast<NV_TYPE *>(
|
|
||||||
const_cast<T *>(sinks.get().data<T>()))
|
|
||||||
: nullptr,
|
|
||||||
reinterpret_cast<OUT_NV_TYPE *>(out->data<OutT>()),
|
|
||||||
quant_max_bound,
|
|
||||||
quant_min_bound,
|
|
||||||
in_scale,
|
|
||||||
max_seq_len,
|
|
||||||
num_chunks,
|
|
||||||
num_heads,
|
|
||||||
chunk_size,
|
|
||||||
HEAD_DIM,
|
|
||||||
token_num,
|
|
||||||
speculate_max_draft_token_num);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2;
|
constexpr uint32_t num_frags_z = BLOCK_SIZE / 16 / NUM_WARP_KV * 2;
|
||||||
@@ -1525,9 +1440,6 @@ void MultiQueryAppendC8Attention(
|
|||||||
int sm_count;
|
int sm_count;
|
||||||
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
|
cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id);
|
||||||
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
|
uint32_t chunk_size = static_cast<uint32_t>(max_partition_size);
|
||||||
if (!is_decoder) {
|
|
||||||
chunk_size = static_cast<uint32_t>(encoder_max_partition_size);
|
|
||||||
}
|
|
||||||
|
|
||||||
const int num_chunks = div_up(max_seq_len, chunk_size);
|
const int num_chunks = div_up(max_seq_len, chunk_size);
|
||||||
uint32_t attn_mask_len;
|
uint32_t attn_mask_len;
|
||||||
@@ -1643,31 +1555,15 @@ void MultiQueryAppendC8Attention(
|
|||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
phi::SizeOf(paddle::DataType::FLOAT32) *
|
||||||
static_cast<size_t>(bsz * num_chunks * num_heads));
|
static_cast<size_t>(bsz * num_chunks * num_heads));
|
||||||
} else {
|
} else {
|
||||||
if (ENABLE_PREFILL) {
|
tmp_workspace = allocator->Allocate(
|
||||||
tmp_workspace =
|
phi::SizeOf(qkv.dtype()) *
|
||||||
allocator->Allocate(phi::SizeOf(qkv.dtype()) *
|
static_cast<size_t>(token_num * num_chunks * num_heads * HEAD_DIM));
|
||||||
static_cast<size_t>(token_num * num_chunks *
|
tmp_m = allocator->Allocate(
|
||||||
num_heads * HEAD_DIM));
|
phi::SizeOf(paddle::DataType::FLOAT32) *
|
||||||
tmp_m = allocator->Allocate(
|
static_cast<size_t>(token_num * num_chunks * num_heads));
|
||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
tmp_d = allocator->Allocate(
|
||||||
static_cast<size_t>(token_num * num_chunks * num_heads));
|
phi::SizeOf(paddle::DataType::FLOAT32) *
|
||||||
tmp_d = allocator->Allocate(
|
static_cast<size_t>(token_num * num_chunks * num_heads));
|
||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
|
||||||
static_cast<size_t>(token_num * num_chunks * num_heads));
|
|
||||||
} else {
|
|
||||||
tmp_workspace = allocator->Allocate(
|
|
||||||
phi::SizeOf(qkv.dtype()) *
|
|
||||||
static_cast<size_t>(speculate_max_draft_token_num * bsz *
|
|
||||||
num_chunks * num_heads * HEAD_DIM));
|
|
||||||
tmp_m = allocator->Allocate(
|
|
||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
|
||||||
static_cast<size_t>(speculate_max_draft_token_num * bsz *
|
|
||||||
num_chunks * num_heads));
|
|
||||||
tmp_d = allocator->Allocate(
|
|
||||||
phi::SizeOf(paddle::DataType::FLOAT32) *
|
|
||||||
static_cast<size_t>(speculate_max_draft_token_num * bsz *
|
|
||||||
num_chunks * num_heads));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
launchWithPdlWhenEnabled(
|
launchWithPdlWhenEnabled(
|
||||||
split_kv_kernel,
|
split_kv_kernel,
|
||||||
|
|||||||
Reference in New Issue
Block a user