[Executor] Refactor GetBlockShapeAndSplitKVBlock Kernel (#2989)

* reset decoder_block_shape_q buffer

* refactor GetBlockShapeAndSplitKVBlock Kernel and cudagraph padding batch

* update decode_max_tile_size

* fix pre-commit

* update block_multihead_attn_backend

* update flash attn backend

* update MLA Attention

* update XPU Attention

* update gcu, iluvatar model runner

* Update MTP

* fix MTP bug
Author: RAM
Date: 2025-07-31 00:09:31 +08:00
Committed by: GitHub
Parent: 998968f1e8
Commit: d850660872
13 changed files with 222 additions and 235 deletions

View File

@@ -195,22 +195,25 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
     const paddle::Tensor &seq_lens_encoder,
     const paddle::Tensor &seq_lens_decoder,
     const paddle::Tensor &seq_lens_this_time,
-    const int encoder_block_shape_q, const int decoder_block_shape_q,
-    const int group_size, const int block_size,
-    const int decoder_step_token_num) {
+    paddle::Tensor &decoder_batch_ids,          // Inplace
+    paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
+    paddle::Tensor &decoder_num_blocks_x_cpu,   // Inplace, Pinned Memory
+    paddle::Tensor &max_len_tensor_cpu,         // Inplace, Pinned Memory
+    const int encoder_block_shape_q,
+    const int decoder_block_shape_q,
+    const int group_size,
+    const int block_size,
+    const int decoder_step_token_num)
+{
   auto stream = seq_lens_encoder.stream();
   int bsz = seq_lens_this_time.shape()[0];
-  auto max_len_tensor =
-      GetEmptyTensor({8}, paddle::DataType::INT32, seq_lens_encoder.place());
-  GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
-            max_len_tensor, bsz);
-  // max_len_this_time, max_enc_len_this_time, max_dec_len_this_time,
-  // max_enc_dec_len_this_time, max_just_dec_len_this_time,
-  // max_just_dec_merged_len_this_time, max_system_len,
-  // max_just_dec_len_without_system
-  auto max_len_cpu = max_len_tensor.copy_to(paddle::CPUPlace(), false);
-  auto max_len_cpu_ptr = max_len_cpu.data<int>();
+  paddle::Tensor max_len_tensor_gpu = GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, paddle::DataType::INT32, seq_lens_this_time.place());
+  GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
+            max_len_tensor_gpu, bsz);
+  max_len_tensor_cpu.copy_(max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
+
+  auto max_len_cpu_ptr = max_len_tensor_cpu.data<int>();
   int max_len_this_time = max_len_cpu_ptr[0];
   int max_enc_len_this_time = max_len_cpu_ptr[1];
   int max_dec_len_this_time = max_len_cpu_ptr[2];
@@ -222,14 +225,11 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
   paddle::Tensor encoder_batch_ids;
   paddle::Tensor encoder_tile_ids_per_batch;
   paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
   paddle::Tensor kv_batch_ids;
   paddle::Tensor kv_tile_ids_per_batch;
   paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
-  paddle::Tensor decoder_batch_ids;
-  paddle::Tensor decoder_tile_ids_per_batch;
-  paddle::Tensor decoder_num_blocks_x_cpu; /*cpu*/
   paddle::Tensor max_len_kv_cpu; /*cpu*/
 
   auto max_len_kv =
       GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
@@ -291,92 +291,64 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
     kv_num_blocks_x_cpu =
         GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
   }
-  if (max_just_dec_len_this_time > 0) {
-    const uint32_t decoder_max_tile_size_per_bs_q =
-        div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
-    decoder_batch_ids =
-        GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
-                       paddle::DataType::INT32, seq_lens_encoder.place());
-    decoder_tile_ids_per_batch =
-        GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
-                       paddle::DataType::INT32, seq_lens_encoder.place());
+
+  if (max_just_dec_len_this_time > 0) {
+    // Clear buffer
+    const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
+    const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q;
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));
+
     auto decoder_num_blocks_x =
         GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
     split_q_block<<<1, 32, 0, stream>>>(
-        seq_lens_this_time.data<int>(), seq_lens_encoder.data<int>(),
-        decoder_batch_ids.data<int>(), decoder_tile_ids_per_batch.data<int>(),
-        decoder_num_blocks_x.data<int>(), bsz, decoder_block_shape_q,
+        seq_lens_this_time.data<int>(),
+        seq_lens_encoder.data<int>(),
+        decoder_batch_ids.data<int>(),
+        decoder_tile_ids_per_batch.data<int>(),
+        decoder_num_blocks_x.data<int>(),
+        bsz,
+        decoder_block_shape_q,
         group_size);
-    decoder_num_blocks_x_cpu =
-        decoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
-  } else {
-    decoder_batch_ids =
-        GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
-    decoder_tile_ids_per_batch =
-        GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
-    decoder_num_blocks_x_cpu =
-        GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
+    decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
   }
 
-  return {encoder_batch_ids,
-          encoder_tile_ids_per_batch,
-          encoder_num_blocks_x_cpu, /*cpu*/
-          kv_batch_ids,
-          kv_tile_ids_per_batch,
-          kv_num_blocks_x_cpu, /*cpu*/
-          decoder_batch_ids,
-          decoder_tile_ids_per_batch,
-          decoder_num_blocks_x_cpu, /*cpu*/
-          max_len_kv_cpu /*cpu*/,
-          max_len_cpu};
-}
-
-std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
-    const paddle::DataType &seq_lens_encoder_dtype,
-    const paddle::DataType &seq_lens_decoder_dtype,
-    const paddle::DataType &seq_lens_this_time_dtype) {
-  return {
-      paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
-      paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
-      paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
-      paddle::DataType::INT32, paddle::DataType::INT32};
-}
-
-std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
-    const std::vector<int64_t> &seq_lens_encoder_shape,
-    const std::vector<int64_t> &seq_lens_decoder_shape,
-    const std::vector<int64_t> &seq_lens_this_time_shape) {
-  std::vector<int64_t> dynamic_shape = {-1};
-
-  return {dynamic_shape,
-          dynamic_shape,
-          {1},
-          dynamic_shape,
-          dynamic_shape,
-          {1},
-          dynamic_shape,
-          dynamic_shape,
-          {1},
-          {1},
-          {8}};
-}
+  return {
+      encoder_batch_ids,
+      encoder_tile_ids_per_batch,
+      encoder_num_blocks_x_cpu, /*cpu*/
+      kv_batch_ids,
+      kv_tile_ids_per_batch,
+      kv_num_blocks_x_cpu, /*cpu*/
+      max_len_kv_cpu, /*cpu*/
+  };
+}
 
 PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
-    .Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time"})
-    .Outputs({paddle::Optional("encoder_batch_ids"),
-              paddle::Optional("encoder_tile_ids_per_batch"),
-              paddle::Optional("encoder_num_blocks"),
-              paddle::Optional("kv_batch_ids"),
-              paddle::Optional("kv_tile_ids_per_batch"),
-              paddle::Optional("kv_num_blocks"),
-              paddle::Optional("decoder_batch_ids"),
-              paddle::Optional("decoder_tile_ids_per_batch"),
-              paddle::Optional("decoder_num_blocks"),
-              paddle::Optional("max_len_kv"), "set_max_lengths"})
-    .Attrs({"encoder_block_shape_q: int", "decoder_block_shape_q: int",
-            "group_size: int", "block_size: int",
-            "decoder_step_token_num: int"})
-    .SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock))
-    .SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape))
-    .SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype));
+    .Inputs({
+        "seq_lens_encoder",
+        "seq_lens_decoder",
+        "seq_lens_this_time",
+        "decoder_batch_ids",
+        "decoder_tile_ids_per_batch",
+        "decoder_num_blocks_x_cpu",
+        "max_len_tensor_cpu"
+    })
+    .Outputs({
+        paddle::Optional("encoder_batch_ids"),
+        paddle::Optional("encoder_tile_ids_per_batch"),
+        paddle::Optional("encoder_num_blocks_x_cpu"),
+        paddle::Optional("kv_batch_ids"),
+        paddle::Optional("kv_tile_ids_per_batch"),
+        paddle::Optional("kv_num_blocks_x_cpu"),
+        "max_len_kv_cpu"
    })
+    .Attrs({
+        "encoder_block_shape_q: int",
+        "decoder_block_shape_q: int",
+        "group_size: int",
+        "block_size: int",
+        "decoder_step_token_num: int"
+    })
+    .SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock));
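
For orientation, here is a minimal sketch (not part of the diff) of the calling convention after this refactor: the caller owns the decoder buffers and the pinned/CPU scratch tensors, sizes them once from the config, and the kernel only zeroes and refills them in place instead of allocating and returning them on every step. The config values below are assumptions for illustration; the wrapper name get_block_shape_and_split_kv_block comes from the Python-side changes further down, but its import path is not shown here, so the call is left as a commented shape.

import numpy as np
import paddle

# Illustrative config values (assumptions, not taken from a real config).
max_num_seqs = 64
num_heads, kv_num_heads = 32, 8
encoder_block_shape_q, decoder_block_shape_q = 64, 16
block_size = 64
decoder_step_token_num = 2  # num_speculative_tokens + 1
group_size = num_heads // kv_num_heads

# Persistent buffers owned by the caller; the op clears and fills them in place.
decode_max_tile_size = max_num_seqs * np.ceil((decoder_step_token_num * group_size) / decoder_block_shape_q)
decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()  # pinned host memory
max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()             # eight max-length statistics

# The op now returns only the encoder/kv outputs plus max_len_kv; the decoder
# results come back through the in-place buffers above (call shape only):
# (encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks,
#  kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks, max_len_kv) = \
#     get_block_shape_and_split_kv_block(
#         seq_lens_encoder, seq_lens_decoder, seq_lens_this_time,
#         decoder_batch_ids, decoder_tile_ids_per_batch,
#         decoder_num_blocks_cpu, max_len_tensor_cpu,
#         encoder_block_shape_q, decoder_block_shape_q,
#         group_size, block_size, decoder_step_token_num,
#     )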

View File

@@ -235,8 +235,14 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
     const paddle::Tensor &seq_lens_encoder,
     const paddle::Tensor &seq_lens_decoder,
     const paddle::Tensor &seq_lens_this_time,
-    const int encoder_block_shape_q, const int decoder_block_shape_q,
-    const int group_size, const int block_size,
+    paddle::Tensor &decoder_batch_ids,          // Inplace
+    paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
+    paddle::Tensor &decoder_num_blocks_x_cpu,   // Inplace, Pinned Memory
+    paddle::Tensor &max_len_tensor_cpu,         // Inplace, Pinned Memory
+    const int encoder_block_shape_q,
+    const int decoder_block_shape_q,
+    const int group_size,
+    const int block_size,
     const int decoder_step_token_num);
 
 std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,

View File

@@ -77,6 +77,10 @@ class ForwardMeta:
     decoder_batch_ids: Optional[paddle.Tensor] = None
     # Tile ID for each batch of the decoder. Used by attention backend.
     decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
+    # The number of blocks that attention backend can use in decode stage
+    decoder_num_blocks_cpu: Optional[paddle.Tensor] = None
+    # Recorded multiple lengths related to prefill or decode
+    max_len_tensor_cpu: Optional[paddle.Tensor] = None
     # Sequence length of encoder for ever batch
     seq_lens_encoder: Optional[paddle.Tensor] = None

View File

@@ -48,17 +48,13 @@ class AppendAttentionMetadata(AttentionMetadata):
     AppendAttentionMetadata
     """
 
-    max_len_kv: paddle.Tensor = None
-    set_max_lengths: int = -1
     encoder_batch_ids: paddle.Tensor = None
     encoder_tile_ids_per_batch: paddle.Tensor = None
     encoder_num_blocks: paddle.Tensor = None
     kv_batch_ids: paddle.Tensor = None
     kv_tile_ids_per_batch: paddle.Tensor = None
     kv_num_blocks: paddle.Tensor = None
-    decoder_batch_ids: paddle.Tensor = None
-    decoder_tile_ids_per_batch: paddle.Tensor = None
-    decoder_num_blocks: paddle.Tensor = None
+    max_len_kv: paddle.Tensor = None
 
     _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
@@ -66,8 +62,6 @@ class AppendAttentionMetadata(AttentionMetadata):
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    encoder_block_shape_q: int = -1
-    decoder_block_shape_q: int = -1
     _fuse_kernel_compute_dtype: str = "bf16"
 
     # pd_disaggregation
@@ -89,6 +83,8 @@ class AppendAttentionBackend(AttentionBackend):
         kv_num_heads: int,
         num_heads: int,
         head_dim: int,
+        encoder_block_shape_q: int = -1,
+        decoder_block_shape_q: int = -1,
     ) -> None:
         """
         AppendAttentionBackend __init__
@@ -110,9 +106,12 @@ class AppendAttentionBackend(AttentionBackend):
         self.kv_num_heads: int = kv_num_heads
         self.num_heads: int = num_heads
+        self.group_size: int = self.num_heads // self.kv_num_heads
         self.head_dim: int = fd_config.model_config.head_dim
         self.num_layers: int = fd_config.model_config.num_hidden_layers
         self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 32768))
+        self.encoder_block_shape_q: int = encoder_block_shape_q
+        self.decoder_block_shape_q: int = decoder_block_shape_q
 
         self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
@@ -126,8 +125,6 @@ class AppendAttentionBackend(AttentionBackend):
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attntion metadata hence all layers in the forward pass can reuse it."""
         metadata = AppendAttentionMetadata()
-        metadata.encoder_block_shape_q = 64
-        metadata.decoder_block_shape_q = 16
         metadata.max_partition_size = self.max_partition_size
         metadata.encoder_max_partition_size = self.max_seq_len
         metadata._dtype = paddle.get_default_dtype()
@@ -148,18 +145,18 @@ class AppendAttentionBackend(AttentionBackend):
             metadata.kv_batch_ids,
             metadata.kv_tile_ids_per_batch,
             metadata.kv_num_blocks,
-            metadata.decoder_batch_ids,  # will copy to buffer
-            metadata.decoder_tile_ids_per_batch,  # will copy to buffer
-            metadata.decoder_num_blocks,
             metadata.max_len_kv,
-            metadata.set_max_lengths,
         ) = get_block_shape_and_split_kv_block(
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
-            metadata.encoder_block_shape_q,
-            metadata.decoder_block_shape_q,
-            self.num_heads // self.kv_num_heads,
+            forward_meta.decoder_batch_ids,
+            forward_meta.decoder_tile_ids_per_batch,
+            forward_meta.decoder_num_blocks_cpu,
+            forward_meta.max_len_tensor_cpu,
+            self.encoder_block_shape_q,
+            self.decoder_block_shape_q,
+            self.group_size,
             self.block_size,
             self.speculate_max_draft_token_num + 1,
         )
@@ -181,8 +178,6 @@ class AppendAttentionBackend(AttentionBackend):
         )
 
         self.attention_metadata: AttentionMetadata = metadata
-        forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
-        forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False)
 
     def get_attntion_meta(self) -> AttentionMetadata:
         """get_attntion_meta"""
@@ -249,10 +244,10 @@ class AppendAttentionBackend(AttentionBackend):
             metadata.kv_batch_ids,
             metadata.kv_tile_ids_per_batch,
             metadata.kv_num_blocks,
-            forward_meta.decoder_batch_ids,  # from buffer
-            forward_meta.decoder_tile_ids_per_batch,  # from buffer
-            metadata.decoder_num_blocks,
-            metadata.set_max_lengths,
+            forward_meta.decoder_batch_ids,
+            forward_meta.decoder_tile_ids_per_batch,
+            forward_meta.decoder_num_blocks_cpu,
+            forward_meta.max_len_tensor_cpu,
             metadata.max_len_kv,
             metadata.rotary_embs,
             metadata.attn_mask,
@@ -275,8 +270,8 @@ class AppendAttentionBackend(AttentionBackend):
             getattr(layer, "quant_max_bound", 0.0),
             getattr(layer, "quant_min_bound", 0.0),
             getattr(layer, "out_scale", -1.0),
-            metadata.encoder_block_shape_q,
-            metadata.decoder_block_shape_q,
+            self.encoder_block_shape_q,
+            self.decoder_block_shape_q,
             metadata.max_partition_size,
             metadata.encoder_max_partition_size,
             self.speculate_max_draft_token_num + 1,

View File

@@ -38,17 +38,12 @@ class BlockAttentionMetadata(AttentionMetadata):
     BlockAttentionMetadata
     """
 
-    max_len_kv: paddle.Tensor = None
-    set_max_lengths: int = -1
     encoder_batch_ids: paddle.Tensor = None
     encoder_tile_ids_per_batch: paddle.Tensor = None
     encoder_num_blocks: paddle.Tensor = None
     kv_batch_ids: paddle.Tensor = None
     kv_tile_ids_per_batch: paddle.Tensor = None
     kv_num_blocks: paddle.Tensor = None
-    decoder_batch_ids: paddle.Tensor = None
-    decoder_tile_ids_per_batch: paddle.Tensor = None
-    decoder_num_blocks: paddle.Tensor = None
 
     _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
@@ -56,8 +51,6 @@ class BlockAttentionMetadata(AttentionMetadata):
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    encoder_block_shape_q: int = -1
-    decoder_block_shape_q: int = -1
     _fuse_kernel_compute_dtype: str = "bf16"
 
     # pd_disaggregation
@@ -79,6 +72,8 @@ class BlockAttentionBackend(AttentionBackend):
         kv_num_heads: int,
         num_heads: int,
         head_dim: int,
+        encoder_block_shape_q: int = -1,
+        decoder_block_shape_q: int = -1,
     ):
         """
         BlockAttentionBackend __init__

View File

@@ -53,8 +53,6 @@ class FlashAttentionMetadata(AttentionMetadata):
     FlashAttentionMetadata
     """
 
-    max_len_kv: paddle.Tensor = None
-    set_max_lengths: int = -1
     rotary_embs: Optional[paddle.Tensor] = None
     block_tables: Optional[paddle.Tensor] = None
     encoder_batch_ids: paddle.Tensor = None
@@ -63,12 +61,6 @@ class FlashAttentionMetadata(AttentionMetadata):
     kv_batch_ids: paddle.Tensor = None
     kv_tile_ids_per_batch: paddle.Tensor = None
     kv_num_blocks: paddle.Tensor = None
-    decoder_batch_ids: paddle.Tensor = None
-    decoder_tile_ids_per_batch: paddle.Tensor = None
-    decoder_num_blocks: paddle.Tensor = None
-    encoder_block_shape_q: int = -1
-    decoder_block_shape_q: int = -1
 
     cu_seqlens_q: paddle.Tensor = None
     cu_seqlens_k: paddle.Tensor = None
@@ -100,6 +92,8 @@ class FlashAttentionBackend(AttentionBackend):
         kv_num_heads: int,
         num_heads: int,
         head_dim: int,
+        encoder_block_shape_q: int = -1,
+        decoder_block_shape_q: int = -1,
     ):
         """
         FlashAttentionBackend __init__
@@ -111,10 +105,13 @@ class FlashAttentionBackend(AttentionBackend):
         self.kv_num_heads = kv_num_heads
         self.num_heads = num_heads
+        self.group_size: int = self.num_heads // self.kv_num_heads
         self.head_dim = fd_config.model_config.head_dim
         self.attn_outputsize_tp = self.num_heads * self.head_dim
         self.block_size = fd_config.cache_config.block_size
         self.num_layers: int = fd_config.model_config.num_hidden_layers
+        self.encoder_block_shape_q: int = encoder_block_shape_q
+        self.decoder_block_shape_q: int = decoder_block_shape_q
 
         self.speculative_method = fd_config.speculative_config.method
         self.use_speculate = self.speculative_method is not None
@@ -176,8 +173,6 @@ class FlashAttentionBackend(AttentionBackend):
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         metadata = FlashAttentionMetadata()
-        metadata.encoder_block_shape_q = 64
-        metadata.decoder_block_shape_q = 16
         metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
         metadata.rotary_embs = forward_meta.rotary_embs
         metadata.block_tables = forward_meta.block_tables
@@ -188,18 +183,18 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.kv_batch_ids,
             metadata.kv_tile_ids_per_batch,
             metadata.kv_num_blocks,
-            metadata.decoder_batch_ids,
-            metadata.decoder_tile_ids_per_batch,
-            metadata.decoder_num_blocks,
             metadata.max_len_kv,
-            metadata.set_max_lengths,
         ) = get_block_shape_and_split_kv_block(
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
-            metadata.encoder_block_shape_q,
-            metadata.decoder_block_shape_q,
-            self.num_heads // self.kv_num_heads,
+            forward_meta.decoder_batch_ids,
+            forward_meta.decoder_tile_ids_per_batch,
+            forward_meta.decoder_num_blocks_cpu,
+            forward_meta.max_len_tensor_cpu,
+            self.encoder_block_shape_q,
+            self.decoder_block_shape_q,
+            self.group_size,
             self.block_size,
             self.speculate_max_draft_token_num + 1,
         )
@@ -233,8 +228,6 @@ class FlashAttentionBackend(AttentionBackend):
                 self.rank, int(self.device_id), self.keep_pd_step_flag
             )
         self.attention_metadata = metadata
-        forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
-        forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False)
 
     def forward_mixed(
         self,
@@ -291,8 +284,8 @@ class FlashAttentionBackend(AttentionBackend):
             v,
             metadata.cu_seqlens_q,
             metadata.cu_seqlens_k,
-            max_seqlen_q=metadata.set_max_lengths[0],
-            max_seqlen_k=metadata.set_max_lengths[3],
+            max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
+            max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
             causal=self.causal,
             **self.flash_attn_kwargs,
         )[0].reshape([-1, self.attn_outputsize_tp])

View File

@@ -64,17 +64,13 @@ class MLAAttentionMetadata(AttentionMetadata):
     MLAAttentionMetadata for Multi-Layer Attention
     """
 
-    max_len_kv: paddle.Tensor = None
-    set_max_lengths: int = -1
     encoder_batch_ids: paddle.Tensor = None
     encoder_tile_ids_per_batch: paddle.Tensor = None
     encoder_num_blocks: paddle.Tensor = None
     kv_batch_ids: paddle.Tensor = None
     kv_tile_ids_per_batch: paddle.Tensor = None
     kv_num_blocks: paddle.Tensor = None
-    decoder_batch_ids: paddle.Tensor = None
-    decoder_tile_ids_per_batch: paddle.Tensor = None
-    decoder_num_blocks: paddle.Tensor = None
+    max_len_kv: paddle.Tensor = None
 
     _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
@@ -82,8 +78,6 @@ class MLAAttentionMetadata(AttentionMetadata):
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    encoder_block_shape_q: int = -1
-    decoder_block_shape_q: int = -1
     _fuse_kernel_compute_dtype: str = "bf16"
 
     # pd_disaggregation
@@ -105,6 +99,8 @@ class MLAAttentionBackend(AttentionBackend):
         kv_num_heads: int,
         num_heads: int,
         head_dim: int,
+        encoder_block_shape_q: int = -1,
+        decoder_block_shape_q: int = -1,
     ) -> None:
         """
         MLAAttentionBackend __init__
@@ -128,8 +124,11 @@ class MLAAttentionBackend(AttentionBackend):
         self.kv_num_heads: int = kv_num_heads
         self.num_heads: int = num_heads
+        self.group_size: int = self.num_heads // self.kv_num_heads
         self.head_dim: int = fd_config.model_config.head_dim
         self.num_layers: int = fd_config.model_config.num_hidden_layers
+        self.encoder_block_shape_q: int = encoder_block_shape_q
+        self.decoder_block_shape_q: int = decoder_block_shape_q
 
         # For Multi Head Latent Attention
         self.kv_lora_rank: int = fd_config.model_config.kv_lora_rank
@@ -152,8 +151,6 @@ class MLAAttentionBackend(AttentionBackend):
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attention metadata hence all layers in the forward pass can reuse it."""
         metadata = MLAAttentionMetadata()
-        metadata.encoder_block_shape_q = 64
-        metadata.decoder_block_shape_q = 16
         metadata.max_partition_size = 32768
         metadata.encoder_max_partition_size = self.max_seq_len
         metadata._dtype = paddle.get_default_dtype()
@@ -176,27 +173,25 @@ class MLAAttentionBackend(AttentionBackend):
             metadata.kv_batch_ids,
             metadata.kv_tile_ids_per_batch,
             metadata.kv_num_blocks,
-            metadata.decoder_batch_ids,
-            metadata.decoder_tile_ids_per_batch,
-            metadata.decoder_num_blocks,
             metadata.max_len_kv,
-            metadata.set_max_lengths,
         ) = get_block_shape_and_split_kv_block(
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
-            metadata.encoder_block_shape_q,
-            metadata.decoder_block_shape_q,
-            self.num_heads // self.kv_num_heads,
+            forward_meta.decoder_batch_ids,
+            forward_meta.decoder_tile_ids_per_batch,
+            forward_meta.decoder_num_blocks_cpu,
+            forward_meta.max_len_tensor_cpu,
+            self.encoder_block_shape_q,
+            self.decoder_block_shape_q,
+            self.group_size,
             self.block_size,
             self.speculate_max_draft_token_num + 1,
         )
 
         # MLA
-        metadata.max_enc_len_this_time = metadata.set_max_lengths[1]
-        metadata.max_dec_len_this_time = metadata.set_max_lengths[2]
-        forward_meta.max_enc_len_this_time = metadata.set_max_lengths[1]
-        forward_meta.max_dec_len_this_time = metadata.set_max_lengths[2]
+        metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
+        metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
 
         # pd_disaggregation
         metadata.kv_signal_data_list = [None] * self.num_layers
@@ -216,9 +211,6 @@ class MLAAttentionBackend(AttentionBackend):
 
         self.attention_metadata: AttentionMetadata = metadata
-        forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
-        forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False)
 
     def get_attntion_meta(self) -> AttentionMetadata:
         """get_attntion_meta"""
         return self.attention_metadata
@@ -354,8 +346,8 @@ class MLAAttentionBackend(AttentionBackend):
             metadata.kv_num_blocks,
             forward_meta.decoder_batch_ids,
             forward_meta.decoder_tile_ids_per_batch,
-            metadata.decoder_num_blocks,
-            metadata.decoder_num_blocks,  # PaddleNLP passes in decoder_num_blocks_cpu
+            forward_meta.decoder_num_blocks_cpu,
+            forward_meta.decoder_num_blocks_cpu,
             metadata.max_enc_len_this_time,
             metadata.max_dec_len_this_time,
             metadata.max_len_kv,
@@ -476,8 +468,8 @@ class MLAAttentionBackend(AttentionBackend):
             metadata.kv_num_blocks,
             forward_meta.decoder_batch_ids,
             forward_meta.decoder_tile_ids_per_batch,
-            metadata.decoder_num_blocks,
-            metadata.decoder_num_blocks,  # PaddleNLP passes in decoder_num_blocks_cpu
+            forward_meta.decoder_num_blocks_cpu,
+            forward_meta.decoder_num_blocks_cpu,
             metadata.max_enc_len_this_time,
             metadata.max_dec_len_this_time,
             metadata.max_len_kv,

View File

@@ -28,6 +28,10 @@ def get_block_shape_and_split_kv_block(
     seq_lens_encoder: paddle.Tensor,
     seq_lens_decoder: paddle.Tensor,
     seq_lens_this_time: paddle.Tensor,
+    decoder_batch_ids: paddle.Tensor,
+    decoder_tile_ids_per_batch: paddle.Tensor,
+    decoder_num_blocks_x_cpu: paddle.Tensor,
+    max_len_tensor_cpu: paddle.Tensor,
     encoder_block_shape_q: int,
     decoder_block_shape_q: int,
     group_size: int,
@@ -45,15 +49,15 @@ def get_block_shape_and_split_kv_block(
             kv_batch_ids,
             kv_tile_ids_per_batch,
             kv_num_blocks,
-            decoder_batch_ids,
-            decoder_tile_ids_per_batch,
-            decoder_num_blocks,
-            max_len_kv,
-            set_max_lengths,
+            max_len_kv_cpu,
         ) = get_block_shape_and_split_kv_block_cuda(
             seq_lens_encoder,
             seq_lens_decoder,
             seq_lens_this_time,
+            decoder_batch_ids,
+            decoder_tile_ids_per_batch,
+            decoder_num_blocks_x_cpu,
+            max_len_tensor_cpu,
             encoder_block_shape_q,
             decoder_block_shape_q,
             group_size,
@@ -67,11 +71,7 @@ def get_block_shape_and_split_kv_block(
             kv_batch_ids,
             kv_tile_ids_per_batch,
             kv_num_blocks,
-            decoder_batch_ids,
-            decoder_tile_ids_per_batch,
-            decoder_num_blocks,
-            max_len_kv,
-            set_max_lengths,
+            max_len_kv_cpu,
         )
     else:
         raise NotImplementedError

View File

@@ -44,26 +44,13 @@ class XPUAttentionMetadata(AttentionMetadata):
     XPUAttentionMetadata
     """
 
-    max_len_kv: paddle.Tensor = None
-    set_max_lengths: int = -1
-    encoder_batch_ids: paddle.Tensor = None
-    encoder_tile_ids_per_batch: paddle.Tensor = None
-    encoder_num_blocks: paddle.Tensor = None
-    kv_batch_ids: paddle.Tensor = None
-    kv_tile_ids_per_batch: paddle.Tensor = None
-    kv_num_blocks: paddle.Tensor = None
-    decoder_batch_ids: paddle.Tensor = None
-    decoder_tile_ids_per_batch: paddle.Tensor = None
-    decoder_num_blocks: paddle.Tensor = None
-
     _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
     max_partition_size: int = 32768
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    encoder_block_shape_q: int = -1
-    decoder_block_shape_q: int = -1
     _fuse_kernel_compute_dtype: str = "bf16"
 
     # pd_disaggregation
@@ -91,7 +78,6 @@ class XPUAttentionBackend(AttentionBackend):
         """
         super().__init__()
         self.attention_metadata: XPUAttentionMetadata = None
-        # TODO(gongshaotian): Use fd_config parameters in the correct location
         self.block_size: int = fd_config.cache_config.block_size
         self.max_seq_len: int = fd_config.parallel_config.max_model_len
         self.rope_theta: float = (
@@ -99,9 +85,6 @@ class XPUAttentionBackend(AttentionBackend):
         )
         self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
         self.causal: bool = getattr(fd_config.model_config, "causal", True)
-        # self.speculate_method = fd_config.parallel_config.speculate_method
-        # self.use_speculate = self.speculate_method is not None
-        # self.speculate_max_draft_token_num = fd_config.parallel_config.speculate_max_draft_tokens
         self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
         self.rank: int = fd_config.parallel_config.tensor_parallel_rank
@@ -117,8 +100,6 @@ class XPUAttentionBackend(AttentionBackend):
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attntion metadata hence all layers in the forward pass can reuse it."""
         metadata = XPUAttentionMetadata()
-        metadata.encoder_block_shape_q = 64
-        metadata.decoder_block_shape_q = 16
         metadata.max_partition_size = 32768
         metadata.encoder_max_partition_size = 32768
         metadata._dtype = paddle.get_default_dtype()

View File

@@ -184,13 +184,26 @@ class MTPProposer(Proposer):
""" """
assert len(self.attn_backends) == 0 assert len(self.attn_backends) == 0
# TODO(gongshaotian): Get rank from config
num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_size num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_size
self.model_config.kv_num_heads = ( self.model_config.kv_num_heads = max(
int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size 1,
int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size,
) )
head_dim = self.model_config.head_dim head_dim = self.model_config.head_dim
# Initialize AttentionBackend buffers
encoder_block_shape_q = 64
decoder_block_shape_q = 16
self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.main_model_inputs["decoder_batch_ids"])
self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like(
self.main_model_inputs["decoder_tile_ids_per_batch"]
)
self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
self.main_model_inputs["decoder_num_blocks_cpu"]
).pin_memory()
self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like(self.main_model_inputs["max_len_tensor_cpu"]).cpu()
# Get the attention backend # Get the attention backend
attn_cls = get_attention_backend() attn_cls = get_attention_backend()
attn_backend = attn_cls( attn_backend = attn_cls(
@@ -198,6 +211,8 @@ class MTPProposer(Proposer):
kv_num_heads=self.model_config.kv_num_heads, kv_num_heads=self.model_config.kv_num_heads,
num_heads=num_heads, num_heads=num_heads,
head_dim=head_dim, head_dim=head_dim,
encoder_block_shape_q=encoder_block_shape_q,
decoder_block_shape_q=decoder_block_shape_q,
) )
if attn_backend is None: if attn_backend is None:
raise NotImplementedError( raise NotImplementedError(
@@ -293,6 +308,12 @@ class MTPProposer(Proposer):
self.model_inputs["base_model_draft_tokens"] = self.main_model_inputs["draft_tokens"] self.model_inputs["base_model_draft_tokens"] = self.main_model_inputs["draft_tokens"]
self.model_inputs["substep"] = 0 self.model_inputs["substep"] = 0
# Declare AttentionBackend buffers
self.model_inputs["decoder_batch_ids"] = None
self.model_inputs["decoder_tile_ids_per_batch"] = None
self.model_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory
self.model_inputs["max_len_tensor_cpu"] = None # CPU
# Input tokens # Input tokens
self.model_inputs["draft_tokens"] = paddle.full(shape=[self.max_num_seqs, 2], fill_value=-1, dtype="int64") self.model_inputs["draft_tokens"] = paddle.full(shape=[self.max_num_seqs, 2], fill_value=-1, dtype="int64")
@@ -405,6 +426,8 @@ class MTPProposer(Proposer):
attn_backend=self.attn_backends[0], attn_backend=self.attn_backends[0],
decoder_batch_ids=self.model_inputs["decoder_batch_ids"], decoder_batch_ids=self.model_inputs["decoder_batch_ids"],
decoder_tile_ids_per_batch=self.model_inputs["decoder_tile_ids_per_batch"], decoder_tile_ids_per_batch=self.model_inputs["decoder_tile_ids_per_batch"],
decoder_num_blocks_cpu=self.model_inputs["decoder_num_blocks_cpu"],
max_len_tensor_cpu=self.model_inputs["max_len_tensor_cpu"],
seq_lens_encoder=self.model_inputs["seq_lens_encoder"], seq_lens_encoder=self.model_inputs["seq_lens_encoder"],
seq_lens_decoder=self.model_inputs["seq_lens_decoder"], seq_lens_decoder=self.model_inputs["seq_lens_decoder"],
seq_lens_this_time=self.model_inputs["seq_lens_this_time"], seq_lens_this_time=self.model_inputs["seq_lens_this_time"],

View File

@@ -417,9 +417,11 @@ class GCUModelRunner(ModelRunnerBase):
self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
# AttentionBackend buffers # Declare AttentionBackend buffers
self.share_inputs["decoder_batch_ids"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["decoder_batch_ids"] = None
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["decoder_tile_ids_per_batch"] = None
self.share_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory
self.share_inputs["max_len_tensor_cpu"] = None # CPU
# Initialize rotary position embedding # Initialize rotary position embedding
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1)) tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
@@ -579,6 +581,8 @@ class GCUModelRunner(ModelRunnerBase):
attn_backend=self.attn_backends[0], attn_backend=self.attn_backends[0],
decoder_batch_ids=self.share_inputs["decoder_batch_ids"], decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"], decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"],
max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"], seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"], seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"], seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
@@ -655,6 +659,18 @@ class GCUModelRunner(ModelRunnerBase):
) )
head_dim = self.model_config.head_dim head_dim = self.model_config.head_dim
# Initialize AttentionBackend buffers
encoder_block_shape_q = 64
decoder_block_shape_q = 16
decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
(decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q
)
self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()
# Get the attention backend # Get the attention backend
attn_cls = get_attention_backend() attn_cls = get_attention_backend()
attn_backend = attn_cls( attn_backend = attn_cls(
@@ -662,6 +678,8 @@ class GCUModelRunner(ModelRunnerBase):
kv_num_heads=self.model_config.kv_num_heads, kv_num_heads=self.model_config.kv_num_heads,
num_heads=num_heads, num_heads=num_heads,
head_dim=head_dim, head_dim=head_dim,
encoder_block_shape_q=encoder_block_shape_q,
decoder_block_shape_q=decoder_block_shape_q,
) )
if attn_backend is None: if attn_backend is None:
raise NotImplementedError( raise NotImplementedError(
@@ -1179,9 +1197,5 @@ class GCUModelRunner(ModelRunnerBase):
Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch. Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch. In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch.
""" """
# TODO(gongshaotian): Use more efficient implementation # In init_attention_metadata, the decode buffer has already been cleared
if self.forward_meta.step_use_cudagraph: return
num_empty_batch = (self.forward_meta.seq_lens_this_time == 0).sum()
for i in range(1, num_empty_batch + 1):
self.forward_meta.decoder_batch_ids[-i] = 0
self.forward_meta.decoder_tile_ids_per_batch[-i] = 0
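
As a quick sanity check of the buffer sizing introduced above, here is a small worked example with assumed values (the real numbers come from the parallel, model, and speculative configs):

import numpy as np

# Assumed example values, not taken from any real config.
max_num_seqs = 64
num_heads, kv_num_heads = 32, 8   # 4 query heads per KV head
decoder_block_shape_q = 16
decoder_step_token_num = 2        # num_speculative_tokens + 1

# Decode tiles needed per sequence, then the total buffer length.
tiles_per_seq = np.ceil((decoder_step_token_num * np.ceil(num_heads / kv_num_heads)) / decoder_block_shape_q)
decode_max_tile_size = max_num_seqs * tiles_per_seq
print(int(tiles_per_seq), int(decode_max_tile_size))  # -> 1 64

With these values each sequence needs one decode tile, so decoder_batch_ids and decoder_tile_ids_per_batch hold 64 entries; the buffers are allocated once at this maximum and afterwards only memset by the kernel, which is why the per-step cudagraph padding cleanup above becomes a no-op.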

View File

@@ -610,9 +610,7 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["step_seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["step_seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["prompt_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int64") self.share_inputs["prompt_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
self.share_inputs["step_idx"] = paddle.full([max_num_seqs, 1], 0, dtype="int64") self.share_inputs["step_idx"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
self.share_inputs["not_need_stop"] = paddle.full( self.share_inputs["not_need_stop"] = paddle.full([1], False, dtype="bool").cpu()
[1], False, dtype="bool"
).cpu() # TODO(gongshaotian): move to pinnd memory
self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool") self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64") self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64")
@@ -643,9 +641,12 @@ class GPUModelRunner(ModelRunnerBase):
self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
# AttentionBackend buffers
self.share_inputs["decoder_batch_ids"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") # Declare AttentionBackend buffers
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["decoder_batch_ids"] = None
self.share_inputs["decoder_tile_ids_per_batch"] = None
self.share_inputs["decoder_num_blocks_cpu"] = None # Pinning Memory
self.share_inputs["max_len_tensor_cpu"] = None # CPU
# Initialize rotary position embedding # Initialize rotary position embedding
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1)) tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
@@ -845,6 +846,8 @@ class GPUModelRunner(ModelRunnerBase):
attn_backend=self.attn_backends[0], attn_backend=self.attn_backends[0],
decoder_batch_ids=self.share_inputs["decoder_batch_ids"], decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"], decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"],
max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"],
seq_lens_encoder=self.share_inputs["seq_lens_encoder"], seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
seq_lens_decoder=self.share_inputs["seq_lens_decoder"], seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
seq_lens_this_time=self.share_inputs["seq_lens_this_time"], seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
@@ -856,7 +859,6 @@ class GPUModelRunner(ModelRunnerBase):
) )
# Update Batch type for cuda graph # Update Batch type for cuda graph
# TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch
only_decode_batch = True only_decode_batch = True
prefill_exists = None prefill_exists = None
# mix ep in single node # mix ep in single node
@@ -946,6 +948,18 @@ class GPUModelRunner(ModelRunnerBase):
) )
head_dim = self.model_config.head_dim head_dim = self.model_config.head_dim
# Initialize AttentionBackend buffers
encoder_block_shape_q = 64
decoder_block_shape_q = 16
decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
(decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q
)
self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()
# Get the attention backend # Get the attention backend
attn_cls = get_attention_backend() attn_cls = get_attention_backend()
attn_backend = attn_cls( attn_backend = attn_cls(
@@ -953,6 +967,8 @@ class GPUModelRunner(ModelRunnerBase):
kv_num_heads=self.model_config.kv_num_heads, kv_num_heads=self.model_config.kv_num_heads,
num_heads=num_heads, num_heads=num_heads,
head_dim=head_dim, head_dim=head_dim,
encoder_block_shape_q=encoder_block_shape_q,
decoder_block_shape_q=decoder_block_shape_q,
) )
self.attn_backends.append(attn_backend) self.attn_backends.append(attn_backend)
@@ -1527,12 +1543,8 @@ class GPUModelRunner(ModelRunnerBase):
Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch. Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch. In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch.
""" """
# TODO(gongshaotian): Use more efficient implementation # In init_attention_metadata, the decode buffer has already been cleared
if self.forward_meta.step_use_cudagraph: return
num_empty_batch = (self.forward_meta.seq_lens_this_time == 0).sum()
for i in range(1, num_empty_batch + 1):
self.forward_meta.decoder_batch_ids[-i] = 0
self.forward_meta.decoder_tile_ids_per_batch[-i] = 0
def _init_image_preprocess(self) -> None: def _init_image_preprocess(self) -> None:
processor = DataProcessor( processor = DataProcessor(

View File

@@ -383,8 +383,8 @@ class IluvatarModelRunner(ModelRunnerBase):
self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
# AttentionBackend buffers # AttentionBackend buffers
self.share_inputs["decoder_batch_ids"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["decoder_batch_ids"] = None
self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([max_num_seqs, 1], 0, dtype="int32") self.share_inputs["decoder_tile_ids_per_batch"] = None
# Initialize rotary position embedding # Initialize rotary position embedding
tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1)) tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))