[BugFix] Rename attention params of deepseekv3 (#2939)

Co-authored-by: K11OntheBoat <ruianmaidanglao@163.com>
This commit is contained in:
K11OntheBoat
2025-07-22 14:01:30 +08:00
committed by GitHub
parent 56102e91e1
commit 8020927f50
7 changed files with 43 additions and 44 deletions

View File

@@ -91,7 +91,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = kv_cache_dims[2];
- meta_data.batch_size = cu_seqlens_q.dims()[0];
+ meta_data.batch_size = seq_lens_decoder.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
@@ -224,7 +224,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
meta_data.max_blocks_per_seq = block_tables.dims()[1];
meta_data.block_size = kv_cache_dims[2];
- meta_data.batch_size = cu_seqlens_q.dims()[0];
+ meta_data.batch_size = seq_lens_encoder.dims()[0];
switch (kv_pe.dtype()) {
case paddle::DataType::BFLOAT16: {
return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
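
[Note] The batch_size fix above is presumably an off-by-one: cu_seqlens_q stores cumulative query lengths, so its first dimension is batch_size + 1, while seq_lens_decoder / seq_lens_encoder have exactly batch_size rows. A minimal sketch of the two layouts, with an illustrative helper that is not part of this patch:

// Illustrative only. For batch_size = 3 with query lengths {4, 1, 2}:
//   cu_seqlens_q     = {0, 4, 5, 7}   -> 4 entries (batch_size + 1)
//   seq_lens_decoder = {d0, d1, d2}   -> 3 entries (batch_size)
#include <cassert>
#include <vector>

int batch_size_from_cu_seqlens(const std::vector<int>& cu_seqlens_q) {
  // A cumulative-length array always carries one extra leading 0 entry.
  return static_cast<int>(cu_seqlens_q.size()) - 1;
}

int main() {
  const std::vector<int> cu_seqlens_q = {0, 4, 5, 7};
  const std::vector<int> seq_lens_decoder = {9, 12, 0};
  assert(batch_size_from_cu_seqlens(cu_seqlens_q) == 3);
  assert(seq_lens_decoder.size() == 3);  // what dims()[0] now reads
  return 0;
}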

View File

@@ -72,7 +72,7 @@ void BatchMLAWithPagedKVCacheKernel(
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
- const paddle::Tensor& padding_offsets,
+ const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -130,7 +130,7 @@ void BatchMLAWithPagedKVCacheKernel(
params.seq_lens_encoder = const_cast<int*>(seq_lens_encoder.data<int>());
params.seq_lens_decoder = const_cast<int*>(seq_lens_decoder.data<int>());
params.cumsum_q_seqlens = const_cast<int*>(cu_seqlens_q.data<int>());
- params.padding_offsets = const_cast<int*>(padding_offsets.data<int>());
+ params.batch_id_per_token = const_cast<int*>(batch_id_per_token.data<int>());
params.batch_ids = const_cast<int*>(batch_ids.data<int>());
params.tile_ids_per_batch = const_cast<int*>(tile_ids_per_batch.data<int>());
params.num_blocks_x = const_cast<int*>(num_blocks_x_device.data<int>());
@@ -143,7 +143,6 @@ void BatchMLAWithPagedKVCacheKernel(
params.o_stride_head_num = v_head_dim;
params.bsz = bsz;
params.token_num = token_num;
- params.max_seq_len = max_seq_len;
params.max_block_num = max_block_num;
params.max_block_num_per_seq = max_block_num_per_seq;
params.q_num_head = q_head_num;
@@ -179,7 +178,7 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
- const paddle::Tensor& padding_offsets,
+ const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,
@@ -213,7 +212,7 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
- const paddle::Tensor& padding_offsets,
+ const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -49,7 +49,7 @@ void BatchMLAWithPagedKVCacheKernel(
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& cu_seqlens_q,
- const paddle::Tensor& padding_offsets,
+ const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
const paddle::Tensor& batch_ids,
const paddle::Tensor& tile_ids_per_batch,

View File

@@ -71,7 +71,7 @@ struct Params {
alignas(16) IdType *seq_lens_encoder;
alignas(16) IdType *seq_lens_decoder;
alignas(16) IdType *cumsum_q_seqlens;
- alignas(16) IdType *padding_offsets;
+ alignas(16) IdType *batch_id_per_token;
alignas(16) IdType *batch_ids;
alignas(16) IdType *tile_ids_per_batch;
@@ -89,7 +89,6 @@ struct Params {
int bsz;
int token_num;
- int max_seq_len;
int max_block_num;
int max_block_num_per_seq;
int q_num_head;
@@ -527,9 +526,9 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
params.seq_lens_this_time,
params.seq_lens_decoder,
params.seq_lens_encoder,
- params.padding_offsets,
+ params.cumsum_q_seqlens,
+ params.batch_id_per_token,
reinterpret_cast<NV_TYPE*>(params.O),
- params.max_seq_len,
params.chunk_num,
params.q_num_head,
params.chunk_size,

View File

@@ -255,9 +255,9 @@ __global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [
const int * __restrict__ seq_lens_this_time,
const int * __restrict__ seq_lens_decoder,
const int * __restrict__ seq_lens_encoder,
- const int * __restrict__ padding_offsets,
+ const int *__restrict__ cu_seqlens_q,
+ const int * __restrict__ batch_id_per_token,
T * __restrict__ out, // [token_num, num_heads, head_dim]
- const int max_seq_len,
const int num_chunks,
const int num_heads,
const int chunk_size,
@@ -270,11 +270,10 @@ __global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [
__shared__ T smem[bdy * HEAD_DIM];
__shared__ float md_smem[bdy * 2];
for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
- const uint32_t ori_token_id = qid + padding_offsets[qid];
- const uint32_t bid = ori_token_id / max_seq_len;
+ const uint32_t bid = batch_id_per_token[qid];
const int seq_len_q = seq_lens_this_time[bid];
if (seq_len_q == 0) continue;
- const uint32_t local_seq_id = ori_token_id % max_seq_len;
+ const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
int seq_len_kv = seq_lens_decoder[bid];
if (seq_len_kv == 0) continue;
seq_len_kv += seq_len_q;
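
[Note] The substantive change in merge_multi_chunks_kernel is how a flattened token index qid is mapped back to its batch and in-sequence position. The old code assumed a padded bsz * max_seq_len layout and recovered the batch id by division; the new code reads the batch id directly from batch_id_per_token and gets the position by subtracting the batch's start offset in cu_seqlens_q, which is also why max_seq_len could be dropped from Params. A side-by-side sketch of the two mappings (standalone functions for illustration; names mirror the diff):

#include <cstdint>

// Old mapping: tokens conceptually live in a padded [bsz, max_seq_len]
// grid; padding_offsets converts a dense index into a padded one.
void old_mapping(int qid, const int* padding_offsets, int max_seq_len,
                 uint32_t* bid, uint32_t* local_seq_id) {
  const uint32_t ori_token_id = qid + padding_offsets[qid];
  *bid = ori_token_id / max_seq_len;
  *local_seq_id = ori_token_id % max_seq_len;
}

// New mapping: tokens stay dense; batch_id_per_token[qid] names the
// owning batch, and cu_seqlens_q[bid] is where that batch's tokens
// start, so the in-sequence position is a single subtraction.
void new_mapping(int qid, const int* batch_id_per_token,
                 const int* cu_seqlens_q,
                 uint32_t* bid, uint32_t* local_seq_id) {
  *bid = batch_id_per_token[qid];
  *local_seq_id = qid - cu_seqlens_q[*bid];
}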

View File

@@ -25,7 +25,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
- const paddle::Tensor& padding_offsets,
+ const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
@@ -97,7 +97,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
seq_lens_decoder,
seq_lens_encoder,
cu_seqlens_q,
- padding_offsets,
+ batch_id_per_token,
block_tables,
decoder_batch_ids,
decoder_tile_ids_per_batch,
@@ -125,7 +125,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
out_linear_smooths,
seq_lens_this_time, // q_seq_len is 1
seq_lens_decoder,
- padding_offsets,
+ batch_id_per_token,
cu_seqlens_q,
block_tables,
max_input_length,
@@ -148,7 +148,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
const paddle::Tensor& seq_lens_decoder,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& cu_seqlens_q,
- const paddle::Tensor& padding_offsets,
+ const paddle::Tensor& batch_id_per_token,
const paddle::Tensor& block_tables,
const paddle::Tensor& encoder_batch_ids,
const paddle::Tensor& encoder_tile_ids_per_batch,
@@ -211,7 +211,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
seq_lens_decoder,
seq_lens_this_time,
cu_seqlens_q,
- padding_offsets,
+ batch_id_per_token,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
@@ -257,7 +257,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
seq_lens_decoder,
seq_lens_this_time,
cu_seqlens_q,
- padding_offsets,
+ batch_id_per_token,
block_tables,
encoder_batch_ids,
encoder_tile_ids_per_batch,
@@ -310,7 +310,7 @@ std::vector<std::vector<int64_t>> MultiHeadLatentAttentionInferShape(
const std::vector<int64_t>& seq_lens_decoder_shape,
const std::vector<int64_t>& seq_lens_this_time_shape,
const std::vector<int64_t>& cu_seqlens_q_shape,
- const std::vector<int64_t>& padding_offsets_shape,
+ const std::vector<int64_t>& batch_id_per_token_shape,
const std::vector<int64_t>& block_tables_shape,
const std::vector<int64_t>& encoder_batch_ids_shape,
const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
@@ -364,7 +364,7 @@ std::vector<paddle::DataType> MultiHeadLatentAttentionInferDtype(
const paddle::DataType& seq_lens_decoder_dtype,
const paddle::DataType& seq_lens_this_time_dtype,
const paddle::DataType& cu_seqlens_q_dtype,
- const paddle::DataType& padding_offsets_dtype,
+ const paddle::DataType& batch_id_per_token_dtype,
const paddle::DataType& block_tables_dtype,
const paddle::DataType& encoder_batch_ids_dtype,
const paddle::DataType& encoder_tile_ids_per_batch_dtype,
@@ -418,7 +418,7 @@ PD_BUILD_OP(multi_head_latent_attention)
"seq_lens_decoder",
"seq_lens_this_time",
"cu_seqlens_q",
"padding_offsets",
"batch_id_per_token",
"block_tables",
"encoder_batch_ids",
"encoder_tile_ids_per_batch",

View File

@@ -213,6 +213,10 @@ class MLAAttentionBackend(AttentionBackend):
self.attention_metadata: AttentionMetadata = metadata
+ forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
+ forward_meta.decoder_tile_ids_per_batch.copy_(
+     metadata.decoder_tile_ids_per_batch, False)
def get_attntion_meta(self) -> AttentionMetadata:
"""get_attntion_meta"""
return self.attention_metadata
@@ -259,8 +263,8 @@ class MLAAttentionBackend(AttentionBackend):
latent_cache,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
- forward_meta.padding_offset,
- forward_meta.cum_offsets,
+ forward_meta.batch_id_per_token,
+ forward_meta.cu_seqlens_q,
metadata.block_tables,
"none",
getattr(forward_meta, "max_input_length", -1),
@@ -298,7 +302,7 @@ class MLAAttentionBackend(AttentionBackend):
"""
metadata = self.attention_metadata
- if self.use_pd_disaggregation:
+ if self.pd_disaggregation_mode == "per_query":
metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise(
metadata.kv_signal_metadata,
layer.layer_id + self.start_layer_index,
@@ -317,8 +321,8 @@ class MLAAttentionBackend(AttentionBackend):
latent_cache,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_encoder,
- forward_meta.padding_offset,
- forward_meta.cum_offsets,
+ forward_meta.batch_id_per_token,
+ forward_meta.cu_seqlens_q,
metadata.block_tables,
"none",
self.max_seq_len,
@@ -334,8 +338,7 @@ class MLAAttentionBackend(AttentionBackend):
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
- forward_meta.padding_offset,
- forward_meta.cum_offsets,
+ forward_meta.batch_id_per_token,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
@@ -343,8 +346,8 @@ class MLAAttentionBackend(AttentionBackend):
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
- metadata.decoder_batch_ids,
- metadata.decoder_tile_ids_per_batch,
+ forward_meta.decoder_batch_ids,
+ forward_meta.decoder_tile_ids_per_batch,
metadata.decoder_num_blocks,
metadata.decoder_num_blocks,  # PaddleNLP passes in decoder_num_blocks_cpu here
metadata.max_enc_len_this_time,
@@ -394,7 +397,7 @@ class MLAAttentionBackend(AttentionBackend):
speculate_decoder = self.speculative_method is not None
speculate_max_tokens = self.speculate_max_draft_token_num
- if self.use_pd_disaggregation:
+ if self.pd_disaggregation_mode == "per_query":
metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise(
metadata.kv_signal_metadata,
layer.layer_id + self.start_layer_index,
@@ -409,8 +412,8 @@ class MLAAttentionBackend(AttentionBackend):
latent_cache,
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
- forward_meta.padding_offset,
- forward_meta.cum_offsets,
+ forward_meta.batch_id_per_token,
+ forward_meta.cu_seqlens_q,
metadata.block_tables,
"none",
self.max_seq_len,
@@ -440,8 +443,8 @@ class MLAAttentionBackend(AttentionBackend):
latent_cache,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_encoder,
- forward_meta.padding_offset,
- forward_meta.cum_offsets,
+ forward_meta.batch_id_per_token,
+ forward_meta.cu_seqlens_q,
metadata.block_tables,
"none",
self.max_seq_len,
@@ -457,8 +460,7 @@ class MLAAttentionBackend(AttentionBackend):
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
forward_meta.cu_seqlens_q,
- forward_meta.padding_offset,
- forward_meta.cum_offsets,
+ forward_meta.batch_id_per_token,
metadata.block_tables,
metadata.encoder_batch_ids,
metadata.encoder_tile_ids_per_batch,
@@ -466,8 +468,8 @@ class MLAAttentionBackend(AttentionBackend):
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
- metadata.decoder_batch_ids,
- metadata.decoder_tile_ids_per_batch,
+ forward_meta.decoder_batch_ids,
+ forward_meta.decoder_tile_ids_per_batch,
metadata.decoder_num_blocks,
metadata.decoder_num_blocks,  # PaddleNLP passes in decoder_num_blocks_cpu here
metadata.max_enc_len_this_time,