[Executor] Refactor GetBlockShapeAndSplitKVBlock Kernel (#2989)

* reset decoder_block_shape_q buffer

* refactor GetBlockShapeAndSplitKVBlock Kernel and cudagraph padding batch

* update decode_max_tile_size

* fix pre-commit

* update block_multihead_attn_backend

* update flas attn backend

* update MLA Attention

* update XPU Attention

* update gcu,iluvatar model runner

* Update MTP

* fix MTP bug
This commit is contained in:
RAM
2025-07-31 00:09:31 +08:00
committed by GitHub
parent 998968f1e8
commit d850660872
13 changed files with 222 additions and 235 deletions

View File

@@ -195,22 +195,25 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
const int encoder_block_shape_q, const int decoder_block_shape_q,
const int group_size, const int block_size,
const int decoder_step_token_num) {
paddle::Tensor &decoder_batch_ids, // Inplace
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const int decoder_step_token_num)
{
auto stream = seq_lens_encoder.stream();
int bsz = seq_lens_this_time.shape()[0];
auto max_len_tensor =
GetEmptyTensor({8}, paddle::DataType::INT32, seq_lens_encoder.place());
GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
max_len_tensor, bsz);
// max_len_this_time, max_enc_len_this_time, max_dec_len_this_time,
// max_enc_dec_len_this_time, max_just_dec_len_this_time,
// max_just_dec_merged_len_this_time, max_system_len,
// max_just_dec_len_without_system
auto max_len_cpu = max_len_tensor.copy_to(paddle::CPUPlace(), false);
auto max_len_cpu_ptr = max_len_cpu.data<int>();
paddle::Tensor max_len_tensor_gpu = GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, paddle::DataType::INT32, seq_lens_this_time.place());
GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
max_len_tensor_gpu, bsz);
max_len_tensor_cpu.copy_(max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
auto max_len_cpu_ptr = max_len_tensor_cpu.data<int>();
int max_len_this_time = max_len_cpu_ptr[0];
int max_enc_len_this_time = max_len_cpu_ptr[1];
int max_dec_len_this_time = max_len_cpu_ptr[2];
@@ -222,14 +225,11 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
paddle::Tensor encoder_batch_ids;
paddle::Tensor encoder_tile_ids_per_batch;
paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
paddle::Tensor kv_batch_ids;
paddle::Tensor kv_tile_ids_per_batch;
paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
paddle::Tensor decoder_batch_ids;
paddle::Tensor decoder_tile_ids_per_batch;
paddle::Tensor decoder_num_blocks_x_cpu; /*cpu*/
paddle::Tensor max_len_kv_cpu; /*cpu*/
paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
paddle::Tensor max_len_kv_cpu; /*cpu*/
auto max_len_kv =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
@@ -291,92 +291,64 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
kv_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
}
if (max_just_dec_len_this_time > 0) {
const uint32_t decoder_max_tile_size_per_bs_q =
div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
decoder_batch_ids =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32, seq_lens_encoder.place());
decoder_tile_ids_per_batch =
GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
paddle::DataType::INT32, seq_lens_encoder.place());
if (max_just_dec_len_this_time > 0) {
// Clear buffer
const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q;
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));
auto decoder_num_blocks_x =
GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
split_q_block<<<1, 32, 0, stream>>>(
seq_lens_this_time.data<int>(), seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(), decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_x.data<int>(), bsz, decoder_block_shape_q,
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
decoder_batch_ids.data<int>(),
decoder_tile_ids_per_batch.data<int>(),
decoder_num_blocks_x.data<int>(),
bsz,
decoder_block_shape_q,
group_size);
decoder_num_blocks_x_cpu =
decoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
} else {
decoder_batch_ids =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
decoder_tile_ids_per_batch =
GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
decoder_num_blocks_x_cpu =
GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
}
return {encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks_x_cpu, /*cpu*/
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks_x_cpu, /*cpu*/
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks_x_cpu, /*cpu*/
max_len_kv_cpu /*cpu*/,
max_len_cpu};
}
std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
const paddle::DataType &seq_lens_encoder_dtype,
const paddle::DataType &seq_lens_decoder_dtype,
const paddle::DataType &seq_lens_this_time_dtype) {
return {
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
paddle::DataType::INT32, paddle::DataType::INT32};
}
std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
const std::vector<int64_t> &seq_lens_encoder_shape,
const std::vector<int64_t> &seq_lens_decoder_shape,
const std::vector<int64_t> &seq_lens_this_time_shape) {
std::vector<int64_t> dynamic_shape = {-1};
return {dynamic_shape,
dynamic_shape,
{1},
dynamic_shape,
dynamic_shape,
{1},
dynamic_shape,
dynamic_shape,
{1},
{1},
{8}};
encoder_batch_ids,
encoder_tile_ids_per_batch,
encoder_num_blocks_x_cpu, /*cpu*/
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks_x_cpu, /*cpu*/
max_len_kv_cpu, /*cpu*/
};
}
PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
.Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time"})
.Outputs({paddle::Optional("encoder_batch_ids"),
paddle::Optional("encoder_tile_ids_per_batch"),
paddle::Optional("encoder_num_blocks"),
paddle::Optional("kv_batch_ids"),
paddle::Optional("kv_tile_ids_per_batch"),
paddle::Optional("kv_num_blocks"),
paddle::Optional("decoder_batch_ids"),
paddle::Optional("decoder_tile_ids_per_batch"),
paddle::Optional("decoder_num_blocks"),
paddle::Optional("max_len_kv"), "set_max_lengths"})
.Attrs({"encoder_block_shape_q: int", "decoder_block_shape_q: int",
"group_size: int", "block_size: int",
"decoder_step_token_num: int"})
.SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock))
.SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype));
.Inputs({
"seq_lens_encoder",
"seq_lens_decoder",
"seq_lens_this_time",
"decoder_batch_ids",
"decoder_tile_ids_per_batch",
"decoder_num_blocks_x_cpu",
"max_len_tensor_cpu"
})
.Outputs({
paddle::Optional("encoder_batch_ids"),
paddle::Optional("encoder_tile_ids_per_batch"),
paddle::Optional("encoder_num_blocks_x_cpu"),
paddle::Optional("kv_batch_ids"),
paddle::Optional("kv_tile_ids_per_batch"),
paddle::Optional("kv_num_blocks_x_cpu"),
"max_len_kv_cpu"
})
.Attrs({
"encoder_block_shape_q: int",
"decoder_block_shape_q: int",
"group_size: int",
"block_size: int",
"decoder_step_token_num: int"
})
.SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock));

View File

@@ -235,8 +235,14 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
const paddle::Tensor &seq_lens_encoder,
const paddle::Tensor &seq_lens_decoder,
const paddle::Tensor &seq_lens_this_time,
const int encoder_block_shape_q, const int decoder_block_shape_q,
const int group_size, const int block_size,
paddle::Tensor &decoder_batch_ids, // Inplace
paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
paddle::Tensor &decoder_num_blocks_x_cpu, // Inplace, Pinned Memory
paddle::Tensor &max_len_tensor_cpu, // Inplace, Pinned Memory
const int encoder_block_shape_q,
const int decoder_block_shape_q,
const int group_size,
const int block_size,
const int decoder_step_token_num);
std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,

View File

@@ -77,6 +77,10 @@ class ForwardMeta:
decoder_batch_ids: Optional[paddle.Tensor] = None
# Tile ID for each batch of the decoder. Used by attention backend.
decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
# The number of blocks that attention backend can use in decode stage
decoder_num_blocks_cpu: Optional[paddle.Tensor] = None
# Recorded multiple lengths related to prefill or decode
max_len_tensor_cpu: Optional[paddle.Tensor] = None
# Sequence length of encoder for ever batch
seq_lens_encoder: Optional[paddle.Tensor] = None

View File

@@ -48,17 +48,13 @@ class AppendAttentionMetadata(AttentionMetadata):
AppendAttentionMetadata
"""
max_len_kv: paddle.Tensor = None
set_max_lengths: int = -1
encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
decoder_batch_ids: paddle.Tensor = None
decoder_tile_ids_per_batch: paddle.Tensor = None
decoder_num_blocks: paddle.Tensor = None
max_len_kv: paddle.Tensor = None
_dtype: paddle.dtype = paddle.bfloat16
encoder_max_partition_size: int = 32768
@@ -66,8 +62,6 @@ class AppendAttentionMetadata(AttentionMetadata):
block_tables: Optional[paddle.Tensor] = None
rotary_embs: Optional[paddle.Tensor] = None
attn_mask: Optional[paddle.Tensor] = None
encoder_block_shape_q: int = -1
decoder_block_shape_q: int = -1
_fuse_kernel_compute_dtype: str = "bf16"
# pd_disaggregation
@@ -89,6 +83,8 @@ class AppendAttentionBackend(AttentionBackend):
kv_num_heads: int,
num_heads: int,
head_dim: int,
encoder_block_shape_q: int = -1,
decoder_block_shape_q: int = -1,
) -> None:
"""
AppendAttentionBackend __init__
@@ -110,9 +106,12 @@ class AppendAttentionBackend(AttentionBackend):
self.kv_num_heads: int = kv_num_heads
self.num_heads: int = num_heads
self.group_size: int = self.num_heads // self.kv_num_heads
self.head_dim: int = fd_config.model_config.head_dim
self.num_layers: int = fd_config.model_config.num_hidden_layers
self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 32768))
self.encoder_block_shape_q: int = encoder_block_shape_q
self.decoder_block_shape_q: int = decoder_block_shape_q
self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
@@ -126,8 +125,6 @@ class AppendAttentionBackend(AttentionBackend):
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
metadata = AppendAttentionMetadata()
metadata.encoder_block_shape_q = 64
metadata.decoder_block_shape_q = 16
metadata.max_partition_size = self.max_partition_size
metadata.encoder_max_partition_size = self.max_seq_len
metadata._dtype = paddle.get_default_dtype()
@@ -148,18 +145,18 @@ class AppendAttentionBackend(AttentionBackend):
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.decoder_batch_ids, # will copy to buffer
metadata.decoder_tile_ids_per_batch, # will copy to buffer
metadata.decoder_num_blocks,
metadata.max_len_kv,
metadata.set_max_lengths,
) = get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
metadata.encoder_block_shape_q,
metadata.decoder_block_shape_q,
self.num_heads // self.kv_num_heads,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.group_size,
self.block_size,
self.speculate_max_draft_token_num + 1,
)
@@ -181,8 +178,6 @@ class AppendAttentionBackend(AttentionBackend):
)
self.attention_metadata: AttentionMetadata = metadata
forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False)
def get_attntion_meta(self) -> AttentionMetadata:
"""get_attntion_meta"""
@@ -249,10 +244,10 @@ class AppendAttentionBackend(AttentionBackend):
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
forward_meta.decoder_batch_ids, # from buffer
forward_meta.decoder_tile_ids_per_batch, # from buffer
metadata.decoder_num_blocks,
metadata.set_max_lengths,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
metadata.max_len_kv,
metadata.rotary_embs,
metadata.attn_mask,
@@ -275,8 +270,8 @@ class AppendAttentionBackend(AttentionBackend):
getattr(layer, "quant_max_bound", 0.0),
getattr(layer, "quant_min_bound", 0.0),
getattr(layer, "out_scale", -1.0),
metadata.encoder_block_shape_q,
metadata.decoder_block_shape_q,
self.encoder_block_shape_q,
self.decoder_block_shape_q,
metadata.max_partition_size,
metadata.encoder_max_partition_size,
self.speculate_max_draft_token_num + 1,

View File

@@ -38,17 +38,12 @@ class BlockAttentionMetadata(AttentionMetadata):
BlockAttentionMetadata
"""
max_len_kv: paddle.Tensor = None
set_max_lengths: int = -1
encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
decoder_batch_ids: paddle.Tensor = None
decoder_tile_ids_per_batch: paddle.Tensor = None
decoder_num_blocks: paddle.Tensor = None
_dtype: paddle.dtype = paddle.bfloat16
encoder_max_partition_size: int = 32768
@@ -56,8 +51,6 @@ class BlockAttentionMetadata(AttentionMetadata):
block_tables: Optional[paddle.Tensor] = None
rotary_embs: Optional[paddle.Tensor] = None
attn_mask: Optional[paddle.Tensor] = None
encoder_block_shape_q: int = -1
decoder_block_shape_q: int = -1
_fuse_kernel_compute_dtype: str = "bf16"
# pd_disaggregation
@@ -79,6 +72,8 @@ class BlockAttentionBackend(AttentionBackend):
kv_num_heads: int,
num_heads: int,
head_dim: int,
encoder_block_shape_q: int = -1,
decoder_block_shape_q: int = -1,
):
"""
BlockAttentionBackend __init__

View File

@@ -53,8 +53,6 @@ class FlashAttentionMetadata(AttentionMetadata):
FlashAttentionMetadata
"""
max_len_kv: paddle.Tensor = None
set_max_lengths: int = -1
rotary_embs: Optional[paddle.Tensor] = None
block_tables: Optional[paddle.Tensor] = None
encoder_batch_ids: paddle.Tensor = None
@@ -63,12 +61,6 @@ class FlashAttentionMetadata(AttentionMetadata):
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
decoder_batch_ids: paddle.Tensor = None
decoder_tile_ids_per_batch: paddle.Tensor = None
decoder_num_blocks: paddle.Tensor = None
encoder_block_shape_q: int = -1
decoder_block_shape_q: int = -1
cu_seqlens_q: paddle.Tensor = None
cu_seqlens_k: paddle.Tensor = None
@@ -100,6 +92,8 @@ class FlashAttentionBackend(AttentionBackend):
kv_num_heads: int,
num_heads: int,
head_dim: int,
encoder_block_shape_q: int = -1,
decoder_block_shape_q: int = -1,
):
"""
FlashAttentionBackend __init__
@@ -111,10 +105,13 @@ class FlashAttentionBackend(AttentionBackend):
self.kv_num_heads = kv_num_heads
self.num_heads = num_heads
self.group_size: int = self.num_heads // self.kv_num_heads
self.head_dim = fd_config.model_config.head_dim
self.attn_outputsize_tp = self.num_heads * self.head_dim
self.block_size = fd_config.cache_config.block_size
self.num_layers: int = fd_config.model_config.num_hidden_layers
self.encoder_block_shape_q: int = encoder_block_shape_q
self.decoder_block_shape_q: int = decoder_block_shape_q
self.speculative_method = fd_config.speculative_config.method
self.use_speculate = self.speculative_method is not None
@@ -176,8 +173,6 @@ class FlashAttentionBackend(AttentionBackend):
def init_attention_metadata(self, forward_meta: ForwardMeta):
metadata = FlashAttentionMetadata()
metadata.encoder_block_shape_q = 64
metadata.decoder_block_shape_q = 16
metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
metadata.rotary_embs = forward_meta.rotary_embs
metadata.block_tables = forward_meta.block_tables
@@ -188,18 +183,18 @@ class FlashAttentionBackend(AttentionBackend):
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.decoder_batch_ids,
metadata.decoder_tile_ids_per_batch,
metadata.decoder_num_blocks,
metadata.max_len_kv,
metadata.set_max_lengths,
) = get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
metadata.encoder_block_shape_q,
metadata.decoder_block_shape_q,
self.num_heads // self.kv_num_heads,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.group_size,
self.block_size,
self.speculate_max_draft_token_num + 1,
)
@@ -233,8 +228,6 @@ class FlashAttentionBackend(AttentionBackend):
self.rank, int(self.device_id), self.keep_pd_step_flag
)
self.attention_metadata = metadata
forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False)
def forward_mixed(
self,
@@ -291,8 +284,8 @@ class FlashAttentionBackend(AttentionBackend):
v,
metadata.cu_seqlens_q,
metadata.cu_seqlens_k,
max_seqlen_q=metadata.set_max_lengths[0],
max_seqlen_k=metadata.set_max_lengths[3],
max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
causal=self.causal,
**self.flash_attn_kwargs,
)[0].reshape([-1, self.attn_outputsize_tp])

View File

@@ -64,17 +64,13 @@ class MLAAttentionMetadata(AttentionMetadata):
MLAAttentionMetadata for Multi-Layer Attention
"""
max_len_kv: paddle.Tensor = None
set_max_lengths: int = -1
encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
decoder_batch_ids: paddle.Tensor = None
decoder_tile_ids_per_batch: paddle.Tensor = None
decoder_num_blocks: paddle.Tensor = None
max_len_kv: paddle.Tensor = None
_dtype: paddle.dtype = paddle.bfloat16
encoder_max_partition_size: int = 32768
@@ -82,8 +78,6 @@ class MLAAttentionMetadata(AttentionMetadata):
block_tables: Optional[paddle.Tensor] = None
rotary_embs: Optional[paddle.Tensor] = None
attn_mask: Optional[paddle.Tensor] = None
encoder_block_shape_q: int = -1
decoder_block_shape_q: int = -1
_fuse_kernel_compute_dtype: str = "bf16"
# pd_disaggregation
@@ -105,6 +99,8 @@ class MLAAttentionBackend(AttentionBackend):
kv_num_heads: int,
num_heads: int,
head_dim: int,
encoder_block_shape_q: int = -1,
decoder_block_shape_q: int = -1,
) -> None:
"""
MLAAttentionBackend __init__
@@ -128,8 +124,11 @@ class MLAAttentionBackend(AttentionBackend):
self.kv_num_heads: int = kv_num_heads
self.num_heads: int = num_heads
self.group_size: int = self.num_heads // self.kv_num_heads
self.head_dim: int = fd_config.model_config.head_dim
self.num_layers: int = fd_config.model_config.num_hidden_layers
self.encoder_block_shape_q: int = encoder_block_shape_q
self.decoder_block_shape_q: int = decoder_block_shape_q
# For Multi Head Latent Attention
self.kv_lora_rank: int = fd_config.model_config.kv_lora_rank
@@ -152,8 +151,6 @@ class MLAAttentionBackend(AttentionBackend):
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attention metadata hence all layers in the forward pass can reuse it."""
metadata = MLAAttentionMetadata()
metadata.encoder_block_shape_q = 64
metadata.decoder_block_shape_q = 16
metadata.max_partition_size = 32768
metadata.encoder_max_partition_size = self.max_seq_len
metadata._dtype = paddle.get_default_dtype()
@@ -176,27 +173,25 @@ class MLAAttentionBackend(AttentionBackend):
metadata.kv_batch_ids,
metadata.kv_tile_ids_per_batch,
metadata.kv_num_blocks,
metadata.decoder_batch_ids,
metadata.decoder_tile_ids_per_batch,
metadata.decoder_num_blocks,
metadata.max_len_kv,
metadata.set_max_lengths,
) = get_block_shape_and_split_kv_block(
forward_meta.seq_lens_encoder,
forward_meta.seq_lens_decoder,
forward_meta.seq_lens_this_time,
metadata.encoder_block_shape_q,
metadata.decoder_block_shape_q,
self.num_heads // self.kv_num_heads,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
forward_meta.decoder_num_blocks_cpu,
forward_meta.max_len_tensor_cpu,
self.encoder_block_shape_q,
self.decoder_block_shape_q,
self.group_size,
self.block_size,
self.speculate_max_draft_token_num + 1,
)
# MLA
metadata.max_enc_len_this_time = metadata.set_max_lengths[1]
metadata.max_dec_len_this_time = metadata.set_max_lengths[2]
forward_meta.max_enc_len_this_time = metadata.set_max_lengths[1]
forward_meta.max_dec_len_this_time = metadata.set_max_lengths[2]
metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
# pd_disaggregation
metadata.kv_signal_data_list = [None] * self.num_layers
@@ -216,9 +211,6 @@ class MLAAttentionBackend(AttentionBackend):
self.attention_metadata: AttentionMetadata = metadata
forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False)
def get_attntion_meta(self) -> AttentionMetadata:
"""get_attntion_meta"""
return self.attention_metadata
@@ -354,8 +346,8 @@ class MLAAttentionBackend(AttentionBackend):
metadata.kv_num_blocks,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
metadata.decoder_num_blocks,
metadata.decoder_num_blocks, # PaddleNLP 传入的是 decoder_num_blocks_cpu
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_cpu,
metadata.max_enc_len_this_time,
metadata.max_dec_len_this_time,
metadata.max_len_kv,
@@ -476,8 +468,8 @@ class MLAAttentionBackend(AttentionBackend):
metadata.kv_num_blocks,
forward_meta.decoder_batch_ids,
forward_meta.decoder_tile_ids_per_batch,
metadata.decoder_num_blocks,
metadata.decoder_num_blocks, # PaddleNLP 传入的是 decoder_num_blocks_cpu
forward_meta.decoder_num_blocks_cpu,
forward_meta.decoder_num_blocks_cpu,
metadata.max_enc_len_this_time,
metadata.max_dec_len_this_time,
metadata.max_len_kv,

View File

@@ -28,6 +28,10 @@ def get_block_shape_and_split_kv_block(
seq_lens_encoder: paddle.Tensor,
seq_lens_decoder: paddle.Tensor,
seq_lens_this_time: paddle.Tensor,
decoder_batch_ids: paddle.Tensor,
decoder_tile_ids_per_batch: paddle.Tensor,
decoder_num_blocks_x_cpu: paddle.Tensor,
max_len_tensor_cpu: paddle.Tensor,
encoder_block_shape_q: int,
decoder_block_shape_q: int,
group_size: int,
@@ -45,15 +49,15 @@ def get_block_shape_and_split_kv_block(
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
max_len_kv,
set_max_lengths,
max_len_kv_cpu,
) = get_block_shape_and_split_kv_block_cuda(
seq_lens_encoder,
seq_lens_decoder,
seq_lens_this_time,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks_x_cpu,
max_len_tensor_cpu,
encoder_block_shape_q,
decoder_block_shape_q,
group_size,
@@ -67,11 +71,7 @@ def get_block_shape_and_split_kv_block(
kv_batch_ids,
kv_tile_ids_per_batch,
kv_num_blocks,
decoder_batch_ids,
decoder_tile_ids_per_batch,
decoder_num_blocks,
max_len_kv,
set_max_lengths,
max_len_kv_cpu,
)
else:
raise NotImplementedError

View File

@@ -44,26 +44,13 @@ class XPUAttentionMetadata(AttentionMetadata):
XPUAttentionMetadata
"""
max_len_kv: paddle.Tensor = None
set_max_lengths: int = -1
encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
decoder_batch_ids: paddle.Tensor = None
decoder_tile_ids_per_batch: paddle.Tensor = None
decoder_num_blocks: paddle.Tensor = None
_dtype: paddle.dtype = paddle.bfloat16
encoder_max_partition_size: int = 32768
max_partition_size: int = 32768
block_tables: Optional[paddle.Tensor] = None
rotary_embs: Optional[paddle.Tensor] = None
attn_mask: Optional[paddle.Tensor] = None
encoder_block_shape_q: int = -1
decoder_block_shape_q: int = -1
_fuse_kernel_compute_dtype: str = "bf16"
# pd_disaggregation
@@ -91,7 +78,6 @@ class XPUAttentionBackend(AttentionBackend):
"""
super().__init__()
self.attention_metadata: XPUAttentionMetadata = None
# TODO(gongshaotian): Use fd_config parameters in the correct location
self.block_size: int = fd_config.cache_config.block_size
self.max_seq_len: int = fd_config.parallel_config.max_model_len
self.rope_theta: float = (
@@ -99,9 +85,6 @@ class XPUAttentionBackend(AttentionBackend):
)
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
self.causal: bool = getattr(fd_config.model_config, "causal", True)
# self.speculate_method = fd_config.parallel_config.speculate_method
# self.use_speculate = self.speculate_method is not None
# self.speculate_max_draft_token_num = fd_config.parallel_config.speculate_max_draft_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
@@ -117,8 +100,6 @@ class XPUAttentionBackend(AttentionBackend):
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
metadata = XPUAttentionMetadata()
metadata.encoder_block_shape_q = 64
metadata.decoder_block_shape_q = 16
metadata.max_partition_size = 32768
metadata.encoder_max_partition_size = 32768
metadata._dtype = paddle.get_default_dtype()

View File

@@ -184,13 +184,26 @@ class MTPProposer(Proposer):
"""
assert len(self.attn_backends) == 0
# TODO(gongshaotian): Get rank from config
num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_size
self.model_config.kv_num_heads = (
int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size
self.model_config.kv_num_heads = max(
1,
int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size,
)
head_dim = self.model_config.head_dim
# Initialize AttentionBackend buffers
encoder_block_shape_q = 64
decoder_block_shape_q = 16
self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.main_model_inputs["decoder_batch_ids"])
self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like(
self.main_model_inputs["decoder_tile_ids_per_batch"]
)
self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
self.main_model_inputs["decoder_num_blocks_cpu"]
).pin_memory()
self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like(self.main_model_inputs["max_len_tensor_cpu"]).cpu()
# Get the attention backend
attn_cls = get_attention_backend()
attn_backend = attn_cls(
@@ -198,6 +211,8 @@ class MTPProposer(Proposer):
kv_num_heads=self.model_config.kv_num_heads,
num_heads=num_heads,
head_dim=head_dim,
encoder_block_shape_q=encoder_block_shape_q,
decoder_block_shape_q=decoder_block_shape_q,
)
if attn_backend is None:
raise NotImplementedError(
@@ -293,6 +308,12 @@ class MTPProposer(Proposer):
self.model_inputs["base_model_draft_tokens"] = self.main_model_inputs["draft_tokens"]
self.model_inputs["substep"] = 0
# Declare AttentionBackend buffers
self.model_inputs["decoder_batch_ids"] = None
self.model_inputs["decoder_tile_ids_per_batch"] = None
self.model_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory
self.model_inputs["max_len_tensor_cpu"] = None # CPU
# Input tokens
self.model_inputs["draft_tokens"] = paddle.full(shape=[self.max_num_seqs, 2], fill_value=-1, dtype="int64")
@@ -405,6 +426,8 @@ class MTPProposer(Proposer):
attn_backend=self.attn_backends[0],
decoder_batch_ids=self.model_inputs["decoder_batch_ids"],
decoder_tile_ids_per_batch=self.model_inputs["decoder_tile_ids_per_batch"],
decoder_num_blocks_cpu=self.model_inputs["decoder_num_blocks_cpu"],
max_len_tensor_cpu=self.model_inputs["max_len_tensor_cpu"],
seq_lens_encoder=self.model_inputs["seq_lens_encoder"],
seq_lens_decoder=self.model_inputs["seq_lens_decoder"],
seq_lens_this_time=self.model_inputs["seq_lens_this_time"],

View File

@@ -417,9 +417,11 @@ class GCUModelRunner(ModelRunnerBase):
self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
# AttentionBackend buffers
self.share_inputs["decoder_batch_ids"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
# Declare AttentionBackend buffers
self.share_inputs["decoder_batch_ids"] = None
self.share_inputs["decoder_tile_ids_per_batch"] = None
self.share_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory
self.share_inputs["max_len_tensor_cpu"] = None # CPU
# Initialize rotary position embedding
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
@@ -579,6 +581,8 @@ class GCUModelRunner(ModelRunnerBase):
attn_backend=self.attn_backends[0],
decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"],
max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
@@ -655,6 +659,18 @@ class GCUModelRunner(ModelRunnerBase):
)
head_dim = self.model_config.head_dim
# Initialize AttentionBackend buffers
encoder_block_shape_q = 64
decoder_block_shape_q = 16
decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
(decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q
)
self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()
# Get the attention backend
attn_cls = get_attention_backend()
attn_backend = attn_cls(
@@ -662,6 +678,8 @@ class GCUModelRunner(ModelRunnerBase):
kv_num_heads=self.model_config.kv_num_heads,
num_heads=num_heads,
head_dim=head_dim,
encoder_block_shape_q=encoder_block_shape_q,
decoder_block_shape_q=decoder_block_shape_q,
)
if attn_backend is None:
raise NotImplementedError(
@@ -1179,9 +1197,5 @@ class GCUModelRunner(ModelRunnerBase):
Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch.
"""
# TODO(gongshaotian): Use more efficient implementation
if self.forward_meta.step_use_cudagraph:
num_empty_batch = (self.forward_meta.seq_lens_this_time == 0).sum()
for i in range(1, num_empty_batch + 1):
self.forward_meta.decoder_batch_ids[-i] = 0
self.forward_meta.decoder_tile_ids_per_batch[-i] = 0
# In init_attention_metadata, the decode buffer has already been cleared
return

View File

@@ -610,9 +610,7 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["step_seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["prompt_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
self.share_inputs["step_idx"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
self.share_inputs["not_need_stop"] = paddle.full(
[1], False, dtype="bool"
).cpu() # TODO(gongshaotian): move to pinnd memory
self.share_inputs["not_need_stop"] = paddle.full([1], False, dtype="bool").cpu()
self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64")
@@ -643,9 +641,12 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
# AttentionBackend buffers
self.share_inputs["decoder_batch_ids"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
# Declare AttentionBackend buffers
self.share_inputs["decoder_batch_ids"] = None
self.share_inputs["decoder_tile_ids_per_batch"] = None
self.share_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory
self.share_inputs["max_len_tensor_cpu"] = None # CPU
# Initialize rotary position embedding
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
@@ -845,6 +846,8 @@ class GPUModelRunner(ModelRunnerBase):
attn_backend=self.attn_backends[0],
decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"],
max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
@@ -856,7 +859,6 @@ class GPUModelRunner(ModelRunnerBase):
)
# Update Batch type for cuda graph
# TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch
only_decode_batch = True
prefill_exists = None
# mix ep in single node
@@ -946,6 +948,18 @@ class GPUModelRunner(ModelRunnerBase):
)
head_dim = self.model_config.head_dim
# Initialize AttentionBackend buffers
encoder_block_shape_q = 64
decoder_block_shape_q = 16
decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
(decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q
)
self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()
# Get the attention backend
attn_cls = get_attention_backend()
attn_backend = attn_cls(
@@ -953,6 +967,8 @@ class GPUModelRunner(ModelRunnerBase):
kv_num_heads=self.model_config.kv_num_heads,
num_heads=num_heads,
head_dim=head_dim,
encoder_block_shape_q=encoder_block_shape_q,
decoder_block_shape_q=decoder_block_shape_q,
)
self.attn_backends.append(attn_backend)
@@ -1527,12 +1543,8 @@ class GPUModelRunner(ModelRunnerBase):
Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch.
"""
# TODO(gongshaotian): Use more efficient implementation
if self.forward_meta.step_use_cudagraph:
num_empty_batch = (self.forward_meta.seq_lens_this_time == 0).sum()
for i in range(1, num_empty_batch + 1):
self.forward_meta.decoder_batch_ids[-i] = 0
self.forward_meta.decoder_tile_ids_per_batch[-i] = 0
# In init_attention_metadata, the decode buffer has already been cleared
return
def _init_image_preprocess(self) -> None:
processor = DataProcessor(

View File

@@ -383,8 +383,8 @@ class IluvatarModelRunner(ModelRunnerBase):
self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
# AttentionBackend buffers
self.share_inputs["decoder_batch_ids"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["decoder_batch_ids"] = None
self.share_inputs["decoder_tile_ids_per_batch"] = None
# Initialize rotary position embedding
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))