[BugFix] Rename attention params of deepseekv3 (#2939)

Co-authored-by: K11OntheBoat <ruianmaidanglao@163.com>
Author: K11OntheBoat
Date: 2025-07-22 14:01:30 +08:00 (committed by GitHub)
Parent: 56102e91e1
Commit: 8020927f50
7 changed files with 43 additions and 44 deletions


@@ -91,7 +91,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
   meta_data.max_blocks_per_seq = block_tables.dims()[1];
   meta_data.block_size = kv_cache_dims[2];
-  meta_data.batch_size = cu_seqlens_q.dims()[0];
+  meta_data.batch_size = seq_lens_decoder.dims()[0];
   switch (kv_pe.dtype()) {
     case paddle::DataType::BFLOAT16: {
       return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
@@ -224,7 +224,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
   meta_data.max_blocks_per_seq = block_tables.dims()[1];
   meta_data.block_size = kv_cache_dims[2];
-  meta_data.batch_size = cu_seqlens_q.dims()[0];
+  meta_data.batch_size = seq_lens_encoder.dims()[0];
   switch (kv_pe.dtype()) {
     case paddle::DataType::BFLOAT16: {
       return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
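
Both hunks fix the same off-by-one: assuming cu_seqlens_q follows the usual prefix-sum convention (a leading zero plus one cumulative entry per sequence), its first dimension is bsz + 1, while seq_lens_decoder / seq_lens_encoder carry exactly one entry per sequence. A minimal plain-Python sketch with illustrative values:

# Illustrative values only, assuming the standard prefix-sum layout for
# cu_seqlens_q: a leading 0 plus one cumulative entry per sequence.
seq_lens = [2, 1, 4]             # one entry per sequence -> len(seq_lens) == bsz
cu_seqlens_q = [0, 2, 3, 7]      # prefix sums over q lengths -> bsz + 1 entries

batch_size_new = len(seq_lens)       # 3, what the fixed code reads
batch_size_old = len(cu_seqlens_q)   # 4, what the old code read (off by one)
assert batch_size_new == 3 and batch_size_old == 4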


@@ -72,7 +72,7 @@ void BatchMLAWithPagedKVCacheKernel(
     const paddle::Tensor& seq_lens_decoder,
     const paddle::Tensor& seq_lens_encoder,
     const paddle::Tensor& cu_seqlens_q,
-    const paddle::Tensor& padding_offsets,
+    const paddle::Tensor& batch_id_per_token,
     const paddle::Tensor& block_tables,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,
@@ -130,7 +130,7 @@ void BatchMLAWithPagedKVCacheKernel(
   params.seq_lens_encoder = const_cast<int*>(seq_lens_encoder.data<int>());
   params.seq_lens_decoder = const_cast<int*>(seq_lens_decoder.data<int>());
   params.cumsum_q_seqlens = const_cast<int*>(cu_seqlens_q.data<int>());
-  params.padding_offsets = const_cast<int*>(padding_offsets.data<int>());
+  params.batch_id_per_token = const_cast<int*>(batch_id_per_token.data<int>());
   params.batch_ids = const_cast<int*>(batch_ids.data<int>());
   params.tile_ids_per_batch = const_cast<int*>(tile_ids_per_batch.data<int>());
   params.num_blocks_x = const_cast<int*>(num_blocks_x_device.data<int>());
@@ -143,7 +143,6 @@ void BatchMLAWithPagedKVCacheKernel(
   params.o_stride_head_num = v_head_dim;
   params.bsz = bsz;
   params.token_num = token_num;
-  params.max_seq_len = max_seq_len;
   params.max_block_num = max_block_num;
   params.max_block_num_per_seq = max_block_num_per_seq;
   params.q_num_head = q_head_num;
@@ -179,7 +178,7 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
     const paddle::Tensor& seq_lens_decoder,
     const paddle::Tensor& seq_lens_encoder,
     const paddle::Tensor& cu_seqlens_q,
-    const paddle::Tensor& padding_offsets,
+    const paddle::Tensor& batch_id_per_token,
     const paddle::Tensor& block_tables,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,
@@ -213,7 +212,7 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
     const paddle::Tensor& seq_lens_decoder,
     const paddle::Tensor& seq_lens_encoder,
     const paddle::Tensor& cu_seqlens_q,
-    const paddle::Tensor& padding_offsets,
+    const paddle::Tensor& batch_id_per_token,
     const paddle::Tensor& block_tables,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,


@@ -49,7 +49,7 @@ void BatchMLAWithPagedKVCacheKernel(
     const paddle::Tensor& seq_lens_decoder,
     const paddle::Tensor& seq_lens_encoder,
     const paddle::Tensor& cu_seqlens_q,
-    const paddle::Tensor& padding_offsets,
+    const paddle::Tensor& batch_id_per_token,
     const paddle::Tensor& block_tables,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,


@@ -71,7 +71,7 @@ struct Params {
   alignas(16) IdType *seq_lens_encoder;
   alignas(16) IdType *seq_lens_decoder;
   alignas(16) IdType *cumsum_q_seqlens;
-  alignas(16) IdType *padding_offsets;
+  alignas(16) IdType *batch_id_per_token;
   alignas(16) IdType *batch_ids;
   alignas(16) IdType *tile_ids_per_batch;
@@ -89,7 +89,6 @@ struct Params {
   int bsz;
   int token_num;
-  int max_seq_len;
   int max_block_num;
   int max_block_num_per_seq;
   int q_num_head;
@@ -527,9 +526,9 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
       params.seq_lens_this_time,
       params.seq_lens_decoder,
       params.seq_lens_encoder,
-      params.padding_offsets,
+      params.cumsum_q_seqlens,
+      params.batch_id_per_token,
       reinterpret_cast<NV_TYPE*>(params.O),
-      params.max_seq_len,
       params.chunk_num,
       params.q_num_head,
       params.chunk_size,


@@ -255,9 +255,9 @@ __global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [
     const int * __restrict__ seq_lens_this_time,
     const int * __restrict__ seq_lens_decoder,
     const int * __restrict__ seq_lens_encoder,
-    const int * __restrict__ padding_offsets,
+    const int *__restrict__ cu_seqlens_q,
+    const int * __restrict__ batch_id_per_token,
     T * __restrict__ out, // [token_num, num_heads, head_dim]
-    const int max_seq_len,
     const int num_chunks,
     const int num_heads,
     const int chunk_size,
@@ -270,11 +270,10 @@ __global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [
   __shared__ T smem[bdy * HEAD_DIM];
   __shared__ float md_smem[bdy * 2];
   for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
-    const uint32_t ori_token_id = qid + padding_offsets[qid];
-    const uint32_t bid = ori_token_id / max_seq_len;
+    const uint32_t bid = batch_id_per_token[qid];
     const int seq_len_q = seq_lens_this_time[bid];
     if (seq_len_q == 0) continue;
-    const uint32_t local_seq_id = ori_token_id % max_seq_len;
+    const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
     int seq_len_kv = seq_lens_decoder[bid];
     if (seq_len_kv == 0) continue;
     seq_len_kv += seq_len_q;
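
This is the substantive indexing change behind all the renames: the kernel used to recover the batch id and in-sequence position by div/mod over a padded [bsz, max_seq_len] token layout, and now reads a precomputed batch_id_per_token table and subtracts the cu_seqlens_q prefix sum, operating directly on the packed token stream; that is also why max_seq_len drops out of Params. A plain-Python sketch of the equivalence, with illustrative values:

# Old mapping (padded layout) vs. new mapping (packed layout); values are
# illustrative, assuming q lengths [2, 1, 4] and max_seq_len = 8.
max_seq_len = 8
cu_seqlens_q = [0, 2, 3, 7]                  # prefix sums of q lengths
batch_id_per_token = [0, 0, 1, 2, 2, 2, 2]   # one batch id per packed token
padding_offsets = [0, 0, 6, 13, 13, 13, 13]  # old: packed id -> padded-id delta

for qid in range(7):
    # New scheme: direct lookup plus prefix-sum subtraction.
    bid = batch_id_per_token[qid]
    local_seq_id = qid - cu_seqlens_q[bid]
    # Old scheme: div/mod on the padded token id.
    ori_token_id = qid + padding_offsets[qid]
    assert bid == ori_token_id // max_seq_len
    assert local_seq_id == ori_token_id % max_seq_len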


@@ -25,7 +25,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
     const paddle::Tensor& seq_lens_decoder,
     const paddle::Tensor& seq_lens_this_time,
     const paddle::Tensor& cu_seqlens_q,
-    const paddle::Tensor& padding_offsets,
+    const paddle::Tensor& batch_id_per_token,
     const paddle::Tensor& block_tables,
     const paddle::Tensor& encoder_batch_ids,
     const paddle::Tensor& encoder_tile_ids_per_batch,
@@ -97,7 +97,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
       seq_lens_decoder,
       seq_lens_encoder,
       cu_seqlens_q,
-      padding_offsets,
+      batch_id_per_token,
       block_tables,
       decoder_batch_ids,
       decoder_tile_ids_per_batch,
@@ -125,7 +125,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
       out_linear_smooths,
       seq_lens_this_time, // q_seq_len is 1
       seq_lens_decoder,
-      padding_offsets,
+      batch_id_per_token,
       cu_seqlens_q,
       block_tables,
       max_input_length,
@@ -148,7 +148,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
     const paddle::Tensor& seq_lens_decoder,
     const paddle::Tensor& seq_lens_this_time,
     const paddle::Tensor& cu_seqlens_q,
-    const paddle::Tensor& padding_offsets,
+    const paddle::Tensor& batch_id_per_token,
     const paddle::Tensor& block_tables,
     const paddle::Tensor& encoder_batch_ids,
     const paddle::Tensor& encoder_tile_ids_per_batch,
@@ -211,7 +211,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
       seq_lens_decoder,
       seq_lens_this_time,
       cu_seqlens_q,
-      padding_offsets,
+      batch_id_per_token,
       block_tables,
       encoder_batch_ids,
       encoder_tile_ids_per_batch,
@@ -257,7 +257,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
       seq_lens_decoder,
       seq_lens_this_time,
       cu_seqlens_q,
-      padding_offsets,
+      batch_id_per_token,
       block_tables,
       encoder_batch_ids,
       encoder_tile_ids_per_batch,
@@ -310,7 +310,7 @@ std::vector<std::vector<int64_t>> MultiHeadLatentAttentionInferShape(
     const std::vector<int64_t>& seq_lens_decoder_shape,
     const std::vector<int64_t>& seq_lens_this_time_shape,
     const std::vector<int64_t>& cu_seqlens_q_shape,
-    const std::vector<int64_t>& padding_offsets_shape,
+    const std::vector<int64_t>& batch_id_per_token_shape,
     const std::vector<int64_t>& block_tables_shape,
     const std::vector<int64_t>& encoder_batch_ids_shape,
     const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
@@ -364,7 +364,7 @@ std::vector<paddle::DataType> MultiHeadLatentAttentionInferDtype(
     const paddle::DataType& seq_lens_decoder_dtype,
     const paddle::DataType& seq_lens_this_time_dtype,
     const paddle::DataType& cu_seqlens_q_dtype,
-    const paddle::DataType& padding_offsets_dtype,
+    const paddle::DataType& batch_id_per_token_dtype,
     const paddle::DataType& block_tables_dtype,
     const paddle::DataType& encoder_batch_ids_dtype,
     const paddle::DataType& encoder_tile_ids_per_batch_dtype,
@@ -418,7 +418,7 @@ PD_BUILD_OP(multi_head_latent_attention)
"seq_lens_decoder", "seq_lens_decoder",
"seq_lens_this_time", "seq_lens_this_time",
"cu_seqlens_q", "cu_seqlens_q",
"padding_offsets", "batch_id_per_token",
"block_tables", "block_tables",
"encoder_batch_ids", "encoder_batch_ids",
"encoder_tile_ids_per_batch", "encoder_tile_ids_per_batch",


@@ -213,6 +213,10 @@ class MLAAttentionBackend(AttentionBackend):
         self.attention_metadata: AttentionMetadata = metadata
+        forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
+        forward_meta.decoder_tile_ids_per_batch.copy_(
+            metadata.decoder_tile_ids_per_batch, False)
 
     def get_attntion_meta(self) -> AttentionMetadata:
         """get_attntion_meta"""
         return self.attention_metadata
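
A hedged reading of the new copy_ calls (the diff itself states no motivation): the per-step decode scheduling tensors are copied in place into pre-allocated forward_meta buffers rather than rebound, so the destination tensors keep a stable identity and device address across steps; the later hunks consistently switch call sites from metadata.decoder_* to forward_meta.decoder_* to match. An illustrative sketch, with assumed buffer sizes:

import paddle

# Persistent buffer owned by forward_meta (the size 256 is illustrative).
decoder_batch_ids_buf = paddle.zeros([256], dtype="int32")

# Freshly computed scheduling result for this step.
new_batch_ids = paddle.arange(256, dtype="int32")

# In-place, non-blocking copy: the buffer object (and its device pointer)
# survives, unlike `decoder_batch_ids_buf = new_batch_ids`, which would rebind.
decoder_batch_ids_buf.copy_(new_batch_ids, False)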
@@ -259,8 +263,8 @@ class MLAAttentionBackend(AttentionBackend):
             latent_cache,
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
-            forward_meta.padding_offset,
-            forward_meta.cum_offsets,
+            forward_meta.batch_id_per_token,
+            forward_meta.cu_seqlens_q,
             metadata.block_tables,
             "none",
             getattr(forward_meta, "max_input_length", -1),
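
On the Python side the rename tells the same story: the padded-layout helpers (padding_offset, cum_offsets) give way to batch_id_per_token and cu_seqlens_q. How FastDeploy actually materializes batch_id_per_token is outside this diff; a hypothetical helper showing the expected contents, derived purely from cu_seqlens_q, might look like:

import paddle

def build_batch_id_per_token(cu_seqlens_q: paddle.Tensor) -> paddle.Tensor:
    """Hypothetical helper: expand prefix sums into one batch id per token."""
    lens = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).tolist()  # per-sequence q lens
    ids = [paddle.full([n], i, dtype="int32") for i, n in enumerate(lens)]
    return paddle.concat(ids)

# e.g. cu_seqlens_q = [0, 2, 3, 7]  ->  [0, 0, 1, 2, 2, 2, 2]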
@@ -298,7 +302,7 @@ class MLAAttentionBackend(AttentionBackend):
""" """
metadata = self.attention_metadata metadata = self.attention_metadata
if self.use_pd_disaggregation: if self.pd_disaggregation_mode == "per_query":
metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise( metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise(
metadata.kv_signal_metadata, metadata.kv_signal_metadata,
layer.layer_id + self.start_layer_index, layer.layer_id + self.start_layer_index,
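
The guard also changes shape: a boolean use_pd_disaggregation becomes a string-valued pd_disaggregation_mode, presumably so more than one prefill/decode disaggregation strategy can be expressed. "per_query" is the only mode visible in this diff; any other values are an assumption. Minimal sketch:

# Hedged sketch of the flag's type change from bool to a mode string.
pd_disaggregation_mode = "per_query"   # was: use_pd_disaggregation = True

if pd_disaggregation_mode == "per_query":
    pass  # initialize layer-wise KV signals, as the surrounding hunk does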
@@ -317,8 +321,8 @@ class MLAAttentionBackend(AttentionBackend):
             latent_cache,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_encoder,
-            forward_meta.padding_offset,
-            forward_meta.cum_offsets,
+            forward_meta.batch_id_per_token,
+            forward_meta.cu_seqlens_q,
             metadata.block_tables,
             "none",
             self.max_seq_len,
@@ -334,8 +338,7 @@ class MLAAttentionBackend(AttentionBackend):
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
             forward_meta.cu_seqlens_q,
-            forward_meta.padding_offset,
-            forward_meta.cum_offsets,
+            forward_meta.batch_id_per_token,
             metadata.block_tables,
             metadata.encoder_batch_ids,
             metadata.encoder_tile_ids_per_batch,
@@ -343,8 +346,8 @@ class MLAAttentionBackend(AttentionBackend):
             metadata.kv_batch_ids,
             metadata.kv_tile_ids_per_batch,
             metadata.kv_num_blocks,
-            metadata.decoder_batch_ids,
-            metadata.decoder_tile_ids_per_batch,
+            forward_meta.decoder_batch_ids,
+            forward_meta.decoder_tile_ids_per_batch,
             metadata.decoder_num_blocks,
             metadata.decoder_num_blocks,  # PaddleNLP passes in decoder_num_blocks_cpu here
             metadata.max_enc_len_this_time,
@@ -394,7 +397,7 @@ class MLAAttentionBackend(AttentionBackend):
         speculate_decoder = self.speculative_method is not None
         speculate_max_tokens = self.speculate_max_draft_token_num
-        if self.use_pd_disaggregation:
+        if self.pd_disaggregation_mode == "per_query":
             metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise(
                 metadata.kv_signal_metadata,
                 layer.layer_id + self.start_layer_index,
@@ -409,8 +412,8 @@ class MLAAttentionBackend(AttentionBackend):
             latent_cache,
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
-            forward_meta.padding_offset,
-            forward_meta.cum_offsets,
+            forward_meta.batch_id_per_token,
+            forward_meta.cu_seqlens_q,
             metadata.block_tables,
             "none",
             self.max_seq_len,
@@ -440,8 +443,8 @@ class MLAAttentionBackend(AttentionBackend):
             latent_cache,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_encoder,
-            forward_meta.padding_offset,
-            forward_meta.cum_offsets,
+            forward_meta.batch_id_per_token,
+            forward_meta.cu_seqlens_q,
             metadata.block_tables,
             "none",
             self.max_seq_len,
@@ -457,8 +460,7 @@ class MLAAttentionBackend(AttentionBackend):
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
             forward_meta.cu_seqlens_q,
-            forward_meta.padding_offset,
-            forward_meta.cum_offsets,
+            forward_meta.batch_id_per_token,
             metadata.block_tables,
             metadata.encoder_batch_ids,
             metadata.encoder_tile_ids_per_batch,
@@ -466,8 +468,8 @@ class MLAAttentionBackend(AttentionBackend):
             metadata.kv_batch_ids,
             metadata.kv_tile_ids_per_batch,
             metadata.kv_num_blocks,
-            metadata.decoder_batch_ids,
-            metadata.decoder_tile_ids_per_batch,
+            forward_meta.decoder_batch_ids,
+            forward_meta.decoder_tile_ids_per_batch,
             metadata.decoder_num_blocks,
             metadata.decoder_num_blocks,  # PaddleNLP passes in decoder_num_blocks_cpu here
             metadata.max_enc_len_this_time,