Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-22 00:02:10 +08:00

[BugFix] Rename attention params of deepseekv3 (#2939)
Co-authored-by: K11OntheBoat <ruianmaidanglao@163.com>
@@ -91,7 +91,7 @@ std::vector<paddle::Tensor> PrefillMLAWriteCacheKernel(
 
   meta_data.max_blocks_per_seq = block_tables.dims()[1];
   meta_data.block_size = kv_cache_dims[2];
-  meta_data.batch_size = cu_seqlens_q.dims()[0];
+  meta_data.batch_size = seq_lens_decoder.dims()[0];
   switch (kv_pe.dtype()) {
     case paddle::DataType::BFLOAT16: {
       return PrefillMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,

@@ -224,7 +224,7 @@ std::vector<paddle::Tensor> DecodeMLAWriteCacheKernel(
 
   meta_data.max_blocks_per_seq = block_tables.dims()[1];
   meta_data.block_size = kv_cache_dims[2];
-  meta_data.batch_size = cu_seqlens_q.dims()[0];
+  meta_data.batch_size = seq_lens_encoder.dims()[0];
   switch (kv_pe.dtype()) {
     case paddle::DataType::BFLOAT16: {
       return DecodeMLAWriteCache<paddle::DataType::BFLOAT16>(meta_data,
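A plausible reading of the two batch_size fixes above: under the usual FlashAttention-style packed layout, cu_seqlens_q is an exclusive prefix sum over per-sequence query lengths and so carries batch_size + 1 entries, which would make cu_seqlens_q.dims()[0] overstate the batch size by one, while seq_lens_decoder / seq_lens_encoder carry exactly one entry per sequence. A minimal framework-free sketch (Python lists standing in for tensor dims; this reasoning is an assumption, not stated in the commit):

    seq_lens_q = [3, 1, 5]                      # one entry per sequence
    cu_seqlens_q = [0]                          # exclusive prefix sum with leading 0
    for n in seq_lens_q:
        cu_seqlens_q.append(cu_seqlens_q[-1] + n)

    batch_size = len(seq_lens_q)                # 3
    assert len(cu_seqlens_q) == batch_size + 1  # 4 entries, not 3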
@@ -72,7 +72,7 @@ void BatchMLAWithPagedKVCacheKernel(
     const paddle::Tensor& seq_lens_decoder,
     const paddle::Tensor& seq_lens_encoder,
     const paddle::Tensor& cu_seqlens_q,
-    const paddle::Tensor& padding_offsets,
+    const paddle::Tensor& batch_id_per_token,
     const paddle::Tensor& block_tables,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,

@@ -130,7 +130,7 @@ void BatchMLAWithPagedKVCacheKernel(
   params.seq_lens_encoder = const_cast<int*>(seq_lens_encoder.data<int>());
   params.seq_lens_decoder = const_cast<int*>(seq_lens_decoder.data<int>());
   params.cumsum_q_seqlens = const_cast<int*>(cu_seqlens_q.data<int>());
-  params.padding_offsets = const_cast<int*>(padding_offsets.data<int>());
+  params.batch_id_per_token = const_cast<int*>(batch_id_per_token.data<int>());
   params.batch_ids = const_cast<int*>(batch_ids.data<int>());
   params.tile_ids_per_batch = const_cast<int*>(tile_ids_per_batch.data<int>());
   params.num_blocks_x = const_cast<int*>(num_blocks_x_device.data<int>());

@@ -143,7 +143,6 @@ void BatchMLAWithPagedKVCacheKernel(
   params.o_stride_head_num = v_head_dim;
   params.bsz = bsz;
   params.token_num = token_num;
-  params.max_seq_len = max_seq_len;
   params.max_block_num = max_block_num;
   params.max_block_num_per_seq = max_block_num_per_seq;
   params.q_num_head = q_head_num;

@@ -179,7 +178,7 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::bfloat16>(
     const paddle::Tensor& seq_lens_decoder,
     const paddle::Tensor& seq_lens_encoder,
     const paddle::Tensor& cu_seqlens_q,
-    const paddle::Tensor& padding_offsets,
+    const paddle::Tensor& batch_id_per_token,
     const paddle::Tensor& block_tables,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,

@@ -213,7 +212,7 @@ template void BatchMLAWithPagedKVCacheKernel<paddle::float16>(
     const paddle::Tensor& seq_lens_decoder,
     const paddle::Tensor& seq_lens_encoder,
     const paddle::Tensor& cu_seqlens_q,
-    const paddle::Tensor& padding_offsets,
+    const paddle::Tensor& batch_id_per_token,
     const paddle::Tensor& block_tables,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,
@@ -49,7 +49,7 @@ void BatchMLAWithPagedKVCacheKernel(
     const paddle::Tensor& seq_lens_decoder,
     const paddle::Tensor& seq_lens_encoder,
     const paddle::Tensor& cu_seqlens_q,
-    const paddle::Tensor& padding_offsets,
+    const paddle::Tensor& batch_id_per_token,
     const paddle::Tensor& block_tables,
     const paddle::Tensor& batch_ids,
     const paddle::Tensor& tile_ids_per_batch,
@@ -71,7 +71,7 @@ struct Params {
   alignas(16) IdType *seq_lens_encoder;
   alignas(16) IdType *seq_lens_decoder;
   alignas(16) IdType *cumsum_q_seqlens;
-  alignas(16) IdType *padding_offsets;
+  alignas(16) IdType *batch_id_per_token;
 
   alignas(16) IdType *batch_ids;
   alignas(16) IdType *tile_ids_per_batch;

@@ -89,7 +89,6 @@ struct Params {
 
   int bsz;
   int token_num;
-  int max_seq_len;
   int max_block_num;
   int max_block_num_per_seq;
   int q_num_head;
@@ -527,9 +526,9 @@ cudaError_t BatchMLAWithPagedKVCacheKernelTraitsDispatched(Params& params,
       params.seq_lens_this_time,
       params.seq_lens_decoder,
       params.seq_lens_encoder,
-      params.padding_offsets,
+      params.cumsum_q_seqlens,
+      params.batch_id_per_token,
       reinterpret_cast<NV_TYPE*>(params.O),
-      params.max_seq_len,
       params.chunk_num,
       params.q_num_head,
       params.chunk_size,
@@ -255,9 +255,9 @@ __global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [
     const int * __restrict__ seq_lens_this_time,
     const int * __restrict__ seq_lens_decoder,
     const int * __restrict__ seq_lens_encoder,
-    const int * __restrict__ padding_offsets,
+    const int *__restrict__ cu_seqlens_q,
+    const int * __restrict__ batch_id_per_token,
     T * __restrict__ out, // [token_num, num_heads, head_dim]
-    const int max_seq_len,
     const int num_chunks,
     const int num_heads,
     const int chunk_size,

@@ -270,11 +270,10 @@ __global__ void merge_multi_chunks_kernel(const T * __restrict__ multi_out, // [
   __shared__ T smem[bdy * HEAD_DIM];
   __shared__ float md_smem[bdy * 2];
   for (int qid = blockIdx.x; qid < token_num; qid += gridDim.x) {
-    const uint32_t ori_token_id = qid + padding_offsets[qid];
-    const uint32_t bid = ori_token_id / max_seq_len;
+    const uint32_t bid = batch_id_per_token[qid];
     const int seq_len_q = seq_lens_this_time[bid];
     if (seq_len_q == 0) continue;
-    const uint32_t local_seq_id = ori_token_id % max_seq_len;
+    const uint32_t local_seq_id = qid - cu_seqlens_q[bid];
     int seq_len_kv = seq_lens_decoder[bid];
     if (seq_len_kv == 0) continue;
     seq_len_kv += seq_len_q;
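The hunk above is the heart of the rename: merge_multi_chunks_kernel used to recover a token's batch id and in-batch position by adding padding_offsets[qid] and then dividing/modding by max_seq_len (a padded [bsz, max_seq_len] layout); it now reads the batch id directly from batch_id_per_token and subtracts cu_seqlens_q[bid]. A host-side sketch of why the two mappings agree, under the assumption that padding_offsets[qid] counted the pad slots skipped before token qid in the old padded layout:

    max_seq_len = 8
    seq_lens = [3, 1, 5]  # tokens per sequence this step (each <= max_seq_len)

    cu_seqlens_q, batch_id_per_token, padding_offsets = [0], [], []
    for b, n in enumerate(seq_lens):
        for _ in range(n):
            batch_id_per_token.append(b)
            # pads skipped before this token in the padded [bsz, max_seq_len] view
            padding_offsets.append(b * max_seq_len - cu_seqlens_q[-1])
        cu_seqlens_q.append(cu_seqlens_q[-1] + n)

    for qid in range(cu_seqlens_q[-1]):
        ori_token_id = qid + padding_offsets[qid]           # old scheme
        old_bid, old_local = divmod(ori_token_id, max_seq_len)
        bid = batch_id_per_token[qid]                       # new scheme
        local_seq_id = qid - cu_seqlens_q[bid]
        assert (bid, local_seq_id) == (old_bid, old_local)

The direct lookup no longer needs max_seq_len at all, which is why the earlier hunks drop it from Params and from the dispatch arguments.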
@@ -25,7 +25,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
     const paddle::Tensor& seq_lens_decoder,
     const paddle::Tensor& seq_lens_this_time,
     const paddle::Tensor& cu_seqlens_q,
-    const paddle::Tensor& padding_offsets,
+    const paddle::Tensor& batch_id_per_token,
     const paddle::Tensor& block_tables,
     const paddle::Tensor& encoder_batch_ids,
     const paddle::Tensor& encoder_tile_ids_per_batch,

@@ -97,7 +97,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
       seq_lens_decoder,
       seq_lens_encoder,
       cu_seqlens_q,
-      padding_offsets,
+      batch_id_per_token,
       block_tables,
       decoder_batch_ids,
       decoder_tile_ids_per_batch,

@@ -125,7 +125,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttentionKernel(
       out_linear_smooths,
       seq_lens_this_time, // q_seq_len is 1
       seq_lens_decoder,
-      padding_offsets,
+      batch_id_per_token,
       cu_seqlens_q,
       block_tables,
       max_input_length,
@@ -148,7 +148,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
     const paddle::Tensor& seq_lens_decoder,
     const paddle::Tensor& seq_lens_this_time,
     const paddle::Tensor& cu_seqlens_q,
-    const paddle::Tensor& padding_offsets,
+    const paddle::Tensor& batch_id_per_token,
     const paddle::Tensor& block_tables,
     const paddle::Tensor& encoder_batch_ids,
     const paddle::Tensor& encoder_tile_ids_per_batch,

@@ -211,7 +211,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
       seq_lens_decoder,
       seq_lens_this_time,
       cu_seqlens_q,
-      padding_offsets,
+      batch_id_per_token,
       block_tables,
       encoder_batch_ids,
       encoder_tile_ids_per_batch,

@@ -257,7 +257,7 @@ std::vector<paddle::Tensor> MultiHeadLatentAttention(
       seq_lens_decoder,
       seq_lens_this_time,
       cu_seqlens_q,
-      padding_offsets,
+      batch_id_per_token,
       block_tables,
       encoder_batch_ids,
       encoder_tile_ids_per_batch,
@@ -310,7 +310,7 @@ std::vector<std::vector<int64_t>> MultiHeadLatentAttentionInferShape(
     const std::vector<int64_t>& seq_lens_decoder_shape,
     const std::vector<int64_t>& seq_lens_this_time_shape,
     const std::vector<int64_t>& cu_seqlens_q_shape,
-    const std::vector<int64_t>& padding_offsets_shape,
+    const std::vector<int64_t>& batch_id_per_token_shape,
     const std::vector<int64_t>& block_tables_shape,
     const std::vector<int64_t>& encoder_batch_ids_shape,
     const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,

@@ -364,7 +364,7 @@ std::vector<paddle::DataType> MultiHeadLatentAttentionInferDtype(
     const paddle::DataType& seq_lens_decoder_dtype,
     const paddle::DataType& seq_lens_this_time_dtype,
     const paddle::DataType& cu_seqlens_q_dtype,
-    const paddle::DataType& padding_offsets_dtype,
+    const paddle::DataType& batch_id_per_token_dtype,
     const paddle::DataType& block_tables_dtype,
     const paddle::DataType& encoder_batch_ids_dtype,
     const paddle::DataType& encoder_tile_ids_per_batch_dtype,
@@ -418,7 +418,7 @@ PD_BUILD_OP(multi_head_latent_attention)
             "seq_lens_decoder",
             "seq_lens_this_time",
             "cu_seqlens_q",
-            "padding_offsets",
+            "batch_id_per_token",
             "block_tables",
             "encoder_batch_ids",
             "encoder_tile_ids_per_batch",
@@ -213,6 +213,10 @@ class MLAAttentionBackend(AttentionBackend):
 
         self.attention_metadata: AttentionMetadata = metadata
 
+        forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
+        forward_meta.decoder_tile_ids_per_batch.copy_(
+            metadata.decoder_tile_ids_per_batch, False)
+
     def get_attntion_meta(self) -> AttentionMetadata:
         """get_attntion_meta"""
         return self.attention_metadata
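The four added lines stage the per-step decoder scheduling metadata into forward_meta buffers with non-blocking in-place copies; the call-site hunks further down (@@ -343 and @@ -466) then pass forward_meta.decoder_batch_ids and forward_meta.decoder_tile_ids_per_batch instead of the metadata-owned tensors, so the kernels read from a stable buffer. A minimal sketch of the copy semantics, using Paddle's in-place Tensor.copy_(src, blocking) exactly as it appears in the diff (buffer names here are illustrative):

    import paddle

    decoder_batch_ids = paddle.zeros([4], dtype="int32")        # persistent forward_meta buffer
    fresh_ids = paddle.to_tensor([0, 1, 1, 2], dtype="int32")   # this step's metadata

    decoder_batch_ids.copy_(fresh_ids, False)  # blocking=False: asynchronous copy
    print(decoder_batch_ids.numpy())           # [0 1 1 2]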
@@ -259,8 +263,8 @@ class MLAAttentionBackend(AttentionBackend):
            latent_cache,
            forward_meta.seq_lens_encoder,
            forward_meta.seq_lens_decoder,
-           forward_meta.padding_offset,
-           forward_meta.cum_offsets,
+           forward_meta.batch_id_per_token,
+           forward_meta.cu_seqlens_q,
            metadata.block_tables,
            "none",
            getattr(forward_meta, "max_input_length", -1),

@@ -298,7 +302,7 @@ class MLAAttentionBackend(AttentionBackend):
        """
        metadata = self.attention_metadata

-       if self.use_pd_disaggregation:
+       if self.pd_disaggregation_mode == "per_query":
            metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise(
                metadata.kv_signal_metadata,
                layer.layer_id + self.start_layer_index,
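Alongside the tensor renames, the boolean use_pd_disaggregation flag becomes a string-valued pd_disaggregation_mode, and the layer-wise KV signaling now runs only in "per_query" mode (the same substitution appears again at @@ -394 below). A minimal sketch of the gating; any mode name other than "per_query" here is hypothetical, not taken from this diff:

    def should_init_layerwise_signal(pd_disaggregation_mode: str) -> bool:
        # Only the per-query disaggregation path initializes KV signals per layer;
        # other modes (including the empty string for "disabled") skip it.
        return pd_disaggregation_mode == "per_query"

    assert should_init_layerwise_signal("per_query")
    assert not should_init_layerwise_signal("")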
@@ -317,8 +321,8 @@ class MLAAttentionBackend(AttentionBackend):
            latent_cache,
            forward_meta.seq_lens_decoder,
            forward_meta.seq_lens_encoder,
-           forward_meta.padding_offset,
-           forward_meta.cum_offsets,
+           forward_meta.batch_id_per_token,
+           forward_meta.cu_seqlens_q,
            metadata.block_tables,
            "none",
            self.max_seq_len,

@@ -334,8 +338,7 @@ class MLAAttentionBackend(AttentionBackend):
            forward_meta.seq_lens_decoder,
            forward_meta.seq_lens_this_time,
            forward_meta.cu_seqlens_q,
-           forward_meta.padding_offset,
-           forward_meta.cum_offsets,
+           forward_meta.batch_id_per_token,
            metadata.block_tables,
            metadata.encoder_batch_ids,
            metadata.encoder_tile_ids_per_batch,
@@ -343,8 +346,8 @@ class MLAAttentionBackend(AttentionBackend):
            metadata.kv_batch_ids,
            metadata.kv_tile_ids_per_batch,
            metadata.kv_num_blocks,
-           metadata.decoder_batch_ids,
-           metadata.decoder_tile_ids_per_batch,
+           forward_meta.decoder_batch_ids,
+           forward_meta.decoder_tile_ids_per_batch,
            metadata.decoder_num_blocks,
            metadata.decoder_num_blocks,  # PaddleNLP passes in decoder_num_blocks_cpu here
            metadata.max_enc_len_this_time,
@@ -394,7 +397,7 @@ class MLAAttentionBackend(AttentionBackend):
        speculate_decoder = self.speculative_method is not None
        speculate_max_tokens = self.speculate_max_draft_token_num

-       if self.use_pd_disaggregation:
+       if self.pd_disaggregation_mode == "per_query":
            metadata.kv_signal_data_list[layer.layer_id] = init_signal_layerwise(
                metadata.kv_signal_metadata,
                layer.layer_id + self.start_layer_index,
@@ -409,8 +412,8 @@ class MLAAttentionBackend(AttentionBackend):
            latent_cache,
            forward_meta.seq_lens_encoder,
            forward_meta.seq_lens_decoder,
-           forward_meta.padding_offset,
-           forward_meta.cum_offsets,
+           forward_meta.batch_id_per_token,
+           forward_meta.cu_seqlens_q,
            metadata.block_tables,
            "none",
            self.max_seq_len,

@@ -440,8 +443,8 @@ class MLAAttentionBackend(AttentionBackend):
            latent_cache,
            forward_meta.seq_lens_decoder,
            forward_meta.seq_lens_encoder,
-           forward_meta.padding_offset,
-           forward_meta.cum_offsets,
+           forward_meta.batch_id_per_token,
+           forward_meta.cu_seqlens_q,
            metadata.block_tables,
            "none",
            self.max_seq_len,
@@ -457,8 +460,7 @@ class MLAAttentionBackend(AttentionBackend):
            forward_meta.seq_lens_decoder,
            forward_meta.seq_lens_this_time,
            forward_meta.cu_seqlens_q,
-           forward_meta.padding_offset,
-           forward_meta.cum_offsets,
+           forward_meta.batch_id_per_token,
            metadata.block_tables,
            metadata.encoder_batch_ids,
            metadata.encoder_tile_ids_per_batch,
@@ -466,8 +468,8 @@ class MLAAttentionBackend(AttentionBackend):
            metadata.kv_batch_ids,
            metadata.kv_tile_ids_per_batch,
            metadata.kv_num_blocks,
-           metadata.decoder_batch_ids,
-           metadata.decoder_tile_ids_per_batch,
+           forward_meta.decoder_batch_ids,
+           forward_meta.decoder_tile_ids_per_batch,
            metadata.decoder_num_blocks,
            metadata.decoder_num_blocks,  # PaddleNLP passes in decoder_num_blocks_cpu here
            metadata.max_enc_len_this_time,