Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Executor] Refactor GetBlockShapeAndSplitKVBlock Kernel (#2989)
* reset decoder_block_shape_q buffer
* refactor GetBlockShapeAndSplitKVBlock Kernel and cudagraph padding batch
* update decode_max_tile_size
* fix pre-commit
* update block_multihead_attn_backend
* update flash attn backend
* update MLA Attention
* update XPU Attention
* update gcu, iluvatar model runner
* Update MTP
* fix MTP bug
@@ -195,22 +195,25 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder,
    const paddle::Tensor &seq_lens_this_time,
    const int encoder_block_shape_q, const int decoder_block_shape_q,
    const int group_size, const int block_size,
    const int decoder_step_token_num) {
    paddle::Tensor &decoder_batch_ids,          // Inplace
    paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
    paddle::Tensor &decoder_num_blocks_x_cpu,   // Inplace, Pinned Memory
    paddle::Tensor &max_len_tensor_cpu,         // Inplace, Pinned Memory
    const int encoder_block_shape_q,
    const int decoder_block_shape_q,
    const int group_size,
    const int block_size,
    const int decoder_step_token_num)
{
  auto stream = seq_lens_encoder.stream();
  int bsz = seq_lens_this_time.shape()[0];
  auto max_len_tensor =
      GetEmptyTensor({8}, paddle::DataType::INT32, seq_lens_encoder.place());
  GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
            max_len_tensor, bsz);

  // max_len_this_time, max_enc_len_this_time, max_dec_len_this_time,
  // max_enc_dec_len_this_time, max_just_dec_len_this_time,
  // max_just_dec_merged_len_this_time, max_system_len,
  // max_just_dec_len_without_system
  auto max_len_cpu = max_len_tensor.copy_to(paddle::CPUPlace(), false);
  auto max_len_cpu_ptr = max_len_cpu.data<int>();
  paddle::Tensor max_len_tensor_gpu = GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, paddle::DataType::INT32, seq_lens_this_time.place());
  GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
            max_len_tensor_gpu, bsz);
  max_len_tensor_cpu.copy_(max_len_tensor_gpu, max_len_tensor_cpu.place(), false);

  auto max_len_cpu_ptr = max_len_tensor_cpu.data<int>();
  int max_len_this_time = max_len_cpu_ptr[0];
  int max_enc_len_this_time = max_len_cpu_ptr[1];
  int max_dec_len_this_time = max_len_cpu_ptr[2];
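For reference, a minimal Python-side sketch of how a caller could interpret the eight slots that GetMaxLen writes into max_len_tensor_cpu. The slot names follow the comment block in the hunk above; the helper function here is illustrative only, not part of the op.

import paddle

# Assumed slot layout of max_len_tensor_cpu, per the kernel comment above.
MAX_LEN_SLOTS = [
    "max_len_this_time",
    "max_enc_len_this_time",
    "max_dec_len_this_time",
    "max_enc_dec_len_this_time",
    "max_just_dec_len_this_time",
    "max_just_dec_merged_len_this_time",
    "max_system_len",
    "max_just_dec_len_without_system",
]

def read_max_lens(max_len_tensor_cpu: paddle.Tensor) -> dict:
    """Illustrative helper: turn the 8-element CPU tensor into a name -> int dict."""
    values = max_len_tensor_cpu.numpy().tolist()
    return dict(zip(MAX_LEN_SLOTS, values))

# Example usage (hypothetical): the decode branch of the kernel is taken when
# max_just_dec_len_this_time > 0.
# lens = read_max_lens(forward_meta.max_len_tensor_cpu)
# has_decode = lens["max_just_dec_len_this_time"] > 0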
@@ -222,14 +225,11 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(

  paddle::Tensor encoder_batch_ids;
  paddle::Tensor encoder_tile_ids_per_batch;
  paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
  paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
  paddle::Tensor kv_batch_ids;
  paddle::Tensor kv_tile_ids_per_batch;
  paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
  paddle::Tensor decoder_batch_ids;
  paddle::Tensor decoder_tile_ids_per_batch;
  paddle::Tensor decoder_num_blocks_x_cpu; /*cpu*/
  paddle::Tensor max_len_kv_cpu; /*cpu*/
  paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
  paddle::Tensor max_len_kv_cpu; /*cpu*/

  auto max_len_kv =
      GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());

@@ -291,92 +291,64 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
    kv_num_blocks_x_cpu =
        GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
  }
  if (max_just_dec_len_this_time > 0) {
    const uint32_t decoder_max_tile_size_per_bs_q =
        div_up((decoder_step_token_num * group_size), decoder_block_shape_q);

    decoder_batch_ids =
        GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
                       paddle::DataType::INT32, seq_lens_encoder.place());
    decoder_tile_ids_per_batch =
        GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
                       paddle::DataType::INT32, seq_lens_encoder.place());
  if (max_just_dec_len_this_time > 0) {
    // Clear buffer
    const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
    const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q;
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));

    auto decoder_num_blocks_x =
        GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
    split_q_block<<<1, 32, 0, stream>>>(
        seq_lens_this_time.data<int>(), seq_lens_encoder.data<int>(),
        decoder_batch_ids.data<int>(), decoder_tile_ids_per_batch.data<int>(),
        decoder_num_blocks_x.data<int>(), bsz, decoder_block_shape_q,
        seq_lens_this_time.data<int>(),
        seq_lens_encoder.data<int>(),
        decoder_batch_ids.data<int>(),
        decoder_tile_ids_per_batch.data<int>(),
        decoder_num_blocks_x.data<int>(),
        bsz,
        decoder_block_shape_q,
        group_size);
    decoder_num_blocks_x_cpu =
        decoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
  } else {
    decoder_batch_ids =
        GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
    decoder_tile_ids_per_batch =
        GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
    decoder_num_blocks_x_cpu =
        GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
    decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
  }

  return {encoder_batch_ids,
          encoder_tile_ids_per_batch,
          encoder_num_blocks_x_cpu, /*cpu*/
          kv_batch_ids,
          kv_tile_ids_per_batch,
          kv_num_blocks_x_cpu, /*cpu*/
          decoder_batch_ids,
          decoder_tile_ids_per_batch,
          decoder_num_blocks_x_cpu, /*cpu*/
          max_len_kv_cpu /*cpu*/,
          max_len_cpu};
}
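A small Python sketch of the decode-tile sizing used above. Only div_up and the formula come from the kernel; the numeric values are made up for illustration. The point of the bound is that the caller can allocate decoder_batch_ids and decoder_tile_ids_per_batch once at this maximum size and let the kernel zero-fill and rewrite them in place on every step.

def div_up(a: int, b: int) -> int:
    """Ceiling division, mirroring the div_up helper used by the kernel."""
    return (a + b - 1) // b

# Illustrative values only.
bsz = 64                      # running batch size
decoder_step_token_num = 2    # e.g. speculate_max_draft_token_num + 1
group_size = 8                # num_heads // kv_num_heads
decoder_block_shape_q = 16

decoder_max_tile_size_per_bs_q = div_up(decoder_step_token_num * group_size,
                                        decoder_block_shape_q)
decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q
print(decoder_max_tile_size_per_bs_q, decoder_batch_shape)  # 1 64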

std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
    const paddle::DataType &seq_lens_encoder_dtype,
    const paddle::DataType &seq_lens_decoder_dtype,
    const paddle::DataType &seq_lens_this_time_dtype) {
  return {
      paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
      paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
      paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
      paddle::DataType::INT32, paddle::DataType::INT32};
}

std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
    const std::vector<int64_t> &seq_lens_encoder_shape,
    const std::vector<int64_t> &seq_lens_decoder_shape,
    const std::vector<int64_t> &seq_lens_this_time_shape) {
  std::vector<int64_t> dynamic_shape = {-1};

  return {dynamic_shape,
          dynamic_shape,
          {1},
          dynamic_shape,
          dynamic_shape,
          {1},
          dynamic_shape,
          dynamic_shape,
          {1},
          {1},
          {8}};
          encoder_batch_ids,
          encoder_tile_ids_per_batch,
          encoder_num_blocks_x_cpu, /*cpu*/
          kv_batch_ids,
          kv_tile_ids_per_batch,
          kv_num_blocks_x_cpu, /*cpu*/
          max_len_kv_cpu, /*cpu*/
  };
}

PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
    .Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time"})
    .Outputs({paddle::Optional("encoder_batch_ids"),
              paddle::Optional("encoder_tile_ids_per_batch"),
              paddle::Optional("encoder_num_blocks"),
              paddle::Optional("kv_batch_ids"),
              paddle::Optional("kv_tile_ids_per_batch"),
              paddle::Optional("kv_num_blocks"),
              paddle::Optional("decoder_batch_ids"),
              paddle::Optional("decoder_tile_ids_per_batch"),
              paddle::Optional("decoder_num_blocks"),
              paddle::Optional("max_len_kv"), "set_max_lengths"})
    .Attrs({"encoder_block_shape_q: int", "decoder_block_shape_q: int",
            "group_size: int", "block_size: int",
            "decoder_step_token_num: int"})
    .SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock))
    .SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype));
    .Inputs({
        "seq_lens_encoder",
        "seq_lens_decoder",
        "seq_lens_this_time",
        "decoder_batch_ids",
        "decoder_tile_ids_per_batch",
        "decoder_num_blocks_x_cpu",
        "max_len_tensor_cpu"
    })
    .Outputs({
        paddle::Optional("encoder_batch_ids"),
        paddle::Optional("encoder_tile_ids_per_batch"),
        paddle::Optional("encoder_num_blocks_x_cpu"),
        paddle::Optional("kv_batch_ids"),
        paddle::Optional("kv_tile_ids_per_batch"),
        paddle::Optional("kv_num_blocks_x_cpu"),
        "max_len_kv_cpu"
    })
    .Attrs({
        "encoder_block_shape_q: int",
        "decoder_block_shape_q: int",
        "group_size: int",
        "block_size: int",
        "decoder_step_token_num: int"
    })
    .SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock));
@@ -235,8 +235,14 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_decoder,
    const paddle::Tensor &seq_lens_this_time,
    const int encoder_block_shape_q, const int decoder_block_shape_q,
    const int group_size, const int block_size,
    paddle::Tensor &decoder_batch_ids,          // Inplace
    paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
    paddle::Tensor &decoder_num_blocks_x_cpu,   // Inplace, Pinned Memory
    paddle::Tensor &max_len_tensor_cpu,         // Inplace, Pinned Memory
    const int encoder_block_shape_q,
    const int decoder_block_shape_q,
    const int group_size,
    const int block_size,
    const int decoder_step_token_num);

std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
@@ -77,6 +77,10 @@ class ForwardMeta:
    decoder_batch_ids: Optional[paddle.Tensor] = None
    # Tile ID for each batch of the decoder. Used by attention backend.
    decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
    # The number of blocks that attention backend can use in decode stage
    decoder_num_blocks_cpu: Optional[paddle.Tensor] = None
    # Recorded multiple lengths related to prefill or decode
    max_len_tensor_cpu: Optional[paddle.Tensor] = None

    # Sequence length of encoder for every batch
    seq_lens_encoder: Optional[paddle.Tensor] = None
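A minimal sketch of how a model runner can pre-allocate these shared ForwardMeta buffers. The shapes and the pin_memory()/cpu() placement mirror the GPUModelRunner changes later in this commit; the constants are illustrative only.

import numpy as np
import paddle

max_num_seqs = 64
decoder_step_token_num = 2    # num_speculative_tokens + 1
group_size = 8                # num_heads // kv_num_heads
decoder_block_shape_q = 16
decode_max_tile_size = max_num_seqs * np.ceil(
    (decoder_step_token_num * group_size) / decoder_block_shape_q
)

# GPU-side tile plan buffers, written in place by the kernel every step.
decoder_batch_ids = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
decoder_tile_ids_per_batch = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
# Host-side outputs: pinned memory for the block count, plain CPU for the 8 max-lengths.
decoder_num_blocks_cpu = paddle.full([1], 0, dtype="int32").pin_memory()
max_len_tensor_cpu = paddle.full([8], 0, dtype="int32").cpu()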
@@ -48,17 +48,13 @@ class AppendAttentionMetadata(AttentionMetadata):
    AppendAttentionMetadata
    """

    max_len_kv: paddle.Tensor = None
    set_max_lengths: int = -1
    encoder_batch_ids: paddle.Tensor = None
    encoder_tile_ids_per_batch: paddle.Tensor = None
    encoder_num_blocks: paddle.Tensor = None
    kv_batch_ids: paddle.Tensor = None
    kv_tile_ids_per_batch: paddle.Tensor = None
    kv_num_blocks: paddle.Tensor = None
    decoder_batch_ids: paddle.Tensor = None
    decoder_tile_ids_per_batch: paddle.Tensor = None
    decoder_num_blocks: paddle.Tensor = None
    max_len_kv: paddle.Tensor = None

    _dtype: paddle.dtype = paddle.bfloat16
    encoder_max_partition_size: int = 32768

@@ -66,8 +62,6 @@ class AppendAttentionMetadata(AttentionMetadata):
    block_tables: Optional[paddle.Tensor] = None
    rotary_embs: Optional[paddle.Tensor] = None
    attn_mask: Optional[paddle.Tensor] = None
    encoder_block_shape_q: int = -1
    decoder_block_shape_q: int = -1
    _fuse_kernel_compute_dtype: str = "bf16"

    # pd_disaggregation

@@ -89,6 +83,8 @@ class AppendAttentionBackend(AttentionBackend):
        kv_num_heads: int,
        num_heads: int,
        head_dim: int,
        encoder_block_shape_q: int = -1,
        decoder_block_shape_q: int = -1,
    ) -> None:
        """
        AppendAttentionBackend __init__

@@ -110,9 +106,12 @@ class AppendAttentionBackend(AttentionBackend):

        self.kv_num_heads: int = kv_num_heads
        self.num_heads: int = num_heads
        self.group_size: int = self.num_heads // self.kv_num_heads
        self.head_dim: int = fd_config.model_config.head_dim
        self.num_layers: int = fd_config.model_config.num_hidden_layers
        self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 32768))
        self.encoder_block_shape_q: int = encoder_block_shape_q
        self.decoder_block_shape_q: int = decoder_block_shape_q

        self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode

@@ -126,8 +125,6 @@ class AppendAttentionBackend(AttentionBackend):
    def init_attention_metadata(self, forward_meta: ForwardMeta):
        """Initialize attention metadata so that all layers in the forward pass can reuse it."""
        metadata = AppendAttentionMetadata()
        metadata.encoder_block_shape_q = 64
        metadata.decoder_block_shape_q = 16
        metadata.max_partition_size = self.max_partition_size
        metadata.encoder_max_partition_size = self.max_seq_len
        metadata._dtype = paddle.get_default_dtype()

@@ -148,18 +145,18 @@ class AppendAttentionBackend(AttentionBackend):
            metadata.kv_batch_ids,
            metadata.kv_tile_ids_per_batch,
            metadata.kv_num_blocks,
            metadata.decoder_batch_ids,  # will copy to buffer
            metadata.decoder_tile_ids_per_batch,  # will copy to buffer
            metadata.decoder_num_blocks,
            metadata.max_len_kv,
            metadata.set_max_lengths,
        ) = get_block_shape_and_split_kv_block(
            forward_meta.seq_lens_encoder,
            forward_meta.seq_lens_decoder,
            forward_meta.seq_lens_this_time,
            metadata.encoder_block_shape_q,
            metadata.decoder_block_shape_q,
            self.num_heads // self.kv_num_heads,
            forward_meta.decoder_batch_ids,
            forward_meta.decoder_tile_ids_per_batch,
            forward_meta.decoder_num_blocks_cpu,
            forward_meta.max_len_tensor_cpu,
            self.encoder_block_shape_q,
            self.decoder_block_shape_q,
            self.group_size,
            self.block_size,
            self.speculate_max_draft_token_num + 1,
        )
@@ -181,8 +178,6 @@ class AppendAttentionBackend(AttentionBackend):
        )

        self.attention_metadata: AttentionMetadata = metadata
        forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
        forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False)

    def get_attntion_meta(self) -> AttentionMetadata:
        """get_attntion_meta"""

@@ -249,10 +244,10 @@ class AppendAttentionBackend(AttentionBackend):
            metadata.kv_batch_ids,
            metadata.kv_tile_ids_per_batch,
            metadata.kv_num_blocks,
            forward_meta.decoder_batch_ids,  # from buffer
            forward_meta.decoder_tile_ids_per_batch,  # from buffer
            metadata.decoder_num_blocks,
            metadata.set_max_lengths,
            forward_meta.decoder_batch_ids,
            forward_meta.decoder_tile_ids_per_batch,
            forward_meta.decoder_num_blocks_cpu,
            forward_meta.max_len_tensor_cpu,
            metadata.max_len_kv,
            metadata.rotary_embs,
            metadata.attn_mask,

@@ -275,8 +270,8 @@ class AppendAttentionBackend(AttentionBackend):
            getattr(layer, "quant_max_bound", 0.0),
            getattr(layer, "quant_min_bound", 0.0),
            getattr(layer, "out_scale", -1.0),
            metadata.encoder_block_shape_q,
            metadata.decoder_block_shape_q,
            self.encoder_block_shape_q,
            self.decoder_block_shape_q,
            metadata.max_partition_size,
            metadata.encoder_max_partition_size,
            self.speculate_max_draft_token_num + 1,

@@ -38,17 +38,12 @@ class BlockAttentionMetadata(AttentionMetadata):
    BlockAttentionMetadata
    """

    max_len_kv: paddle.Tensor = None
    set_max_lengths: int = -1
    encoder_batch_ids: paddle.Tensor = None
    encoder_tile_ids_per_batch: paddle.Tensor = None
    encoder_num_blocks: paddle.Tensor = None
    kv_batch_ids: paddle.Tensor = None
    kv_tile_ids_per_batch: paddle.Tensor = None
    kv_num_blocks: paddle.Tensor = None
    decoder_batch_ids: paddle.Tensor = None
    decoder_tile_ids_per_batch: paddle.Tensor = None
    decoder_num_blocks: paddle.Tensor = None

    _dtype: paddle.dtype = paddle.bfloat16
    encoder_max_partition_size: int = 32768

@@ -56,8 +51,6 @@ class BlockAttentionMetadata(AttentionMetadata):
    block_tables: Optional[paddle.Tensor] = None
    rotary_embs: Optional[paddle.Tensor] = None
    attn_mask: Optional[paddle.Tensor] = None
    encoder_block_shape_q: int = -1
    decoder_block_shape_q: int = -1
    _fuse_kernel_compute_dtype: str = "bf16"

    # pd_disaggregation

@@ -79,6 +72,8 @@ class BlockAttentionBackend(AttentionBackend):
        kv_num_heads: int,
        num_heads: int,
        head_dim: int,
        encoder_block_shape_q: int = -1,
        decoder_block_shape_q: int = -1,
    ):
        """
        BlockAttentionBackend __init__

@@ -53,8 +53,6 @@ class FlashAttentionMetadata(AttentionMetadata):
    FlashAttentionMetadata
    """

    max_len_kv: paddle.Tensor = None
    set_max_lengths: int = -1
    rotary_embs: Optional[paddle.Tensor] = None
    block_tables: Optional[paddle.Tensor] = None
    encoder_batch_ids: paddle.Tensor = None

@@ -63,12 +61,6 @@ class FlashAttentionMetadata(AttentionMetadata):
    kv_batch_ids: paddle.Tensor = None
    kv_tile_ids_per_batch: paddle.Tensor = None
    kv_num_blocks: paddle.Tensor = None
    decoder_batch_ids: paddle.Tensor = None
    decoder_tile_ids_per_batch: paddle.Tensor = None
    decoder_num_blocks: paddle.Tensor = None

    encoder_block_shape_q: int = -1
    decoder_block_shape_q: int = -1

    cu_seqlens_q: paddle.Tensor = None
    cu_seqlens_k: paddle.Tensor = None

@@ -100,6 +92,8 @@ class FlashAttentionBackend(AttentionBackend):
        kv_num_heads: int,
        num_heads: int,
        head_dim: int,
        encoder_block_shape_q: int = -1,
        decoder_block_shape_q: int = -1,
    ):
        """
        FlashAttentionBackend __init__

@@ -111,10 +105,13 @@ class FlashAttentionBackend(AttentionBackend):

        self.kv_num_heads = kv_num_heads
        self.num_heads = num_heads
        self.group_size: int = self.num_heads // self.kv_num_heads
        self.head_dim = fd_config.model_config.head_dim
        self.attn_outputsize_tp = self.num_heads * self.head_dim
        self.block_size = fd_config.cache_config.block_size
        self.num_layers: int = fd_config.model_config.num_hidden_layers
        self.encoder_block_shape_q: int = encoder_block_shape_q
        self.decoder_block_shape_q: int = decoder_block_shape_q

        self.speculative_method = fd_config.speculative_config.method
        self.use_speculate = self.speculative_method is not None

@@ -176,8 +173,6 @@ class FlashAttentionBackend(AttentionBackend):

    def init_attention_metadata(self, forward_meta: ForwardMeta):
        metadata = FlashAttentionMetadata()
        metadata.encoder_block_shape_q = 64
        metadata.decoder_block_shape_q = 16
        metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
        metadata.rotary_embs = forward_meta.rotary_embs
        metadata.block_tables = forward_meta.block_tables

@@ -188,18 +183,18 @@ class FlashAttentionBackend(AttentionBackend):
            metadata.kv_batch_ids,
            metadata.kv_tile_ids_per_batch,
            metadata.kv_num_blocks,
            metadata.decoder_batch_ids,
            metadata.decoder_tile_ids_per_batch,
            metadata.decoder_num_blocks,
            metadata.max_len_kv,
            metadata.set_max_lengths,
        ) = get_block_shape_and_split_kv_block(
            forward_meta.seq_lens_encoder,
            forward_meta.seq_lens_decoder,
            forward_meta.seq_lens_this_time,
            metadata.encoder_block_shape_q,
            metadata.decoder_block_shape_q,
            self.num_heads // self.kv_num_heads,
            forward_meta.decoder_batch_ids,
            forward_meta.decoder_tile_ids_per_batch,
            forward_meta.decoder_num_blocks_cpu,
            forward_meta.max_len_tensor_cpu,
            self.encoder_block_shape_q,
            self.decoder_block_shape_q,
            self.group_size,
            self.block_size,
            self.speculate_max_draft_token_num + 1,
        )

@@ -233,8 +228,6 @@ class FlashAttentionBackend(AttentionBackend):
            self.rank, int(self.device_id), self.keep_pd_step_flag
        )
        self.attention_metadata = metadata
        forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
        forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False)

    def forward_mixed(
        self,

@@ -291,8 +284,8 @@ class FlashAttentionBackend(AttentionBackend):
            v,
            metadata.cu_seqlens_q,
            metadata.cu_seqlens_k,
            max_seqlen_q=metadata.set_max_lengths[0],
            max_seqlen_k=metadata.set_max_lengths[3],
            max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
            max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
            causal=self.causal,
            **self.flash_attn_kwargs,
        )[0].reshape([-1, self.attn_outputsize_tp])

@@ -64,17 +64,13 @@ class MLAAttentionMetadata(AttentionMetadata):
    MLAAttentionMetadata for Multi-Layer Attention
    """

    max_len_kv: paddle.Tensor = None
    set_max_lengths: int = -1
    encoder_batch_ids: paddle.Tensor = None
    encoder_tile_ids_per_batch: paddle.Tensor = None
    encoder_num_blocks: paddle.Tensor = None
    kv_batch_ids: paddle.Tensor = None
    kv_tile_ids_per_batch: paddle.Tensor = None
    kv_num_blocks: paddle.Tensor = None
    decoder_batch_ids: paddle.Tensor = None
    decoder_tile_ids_per_batch: paddle.Tensor = None
    decoder_num_blocks: paddle.Tensor = None
    max_len_kv: paddle.Tensor = None

    _dtype: paddle.dtype = paddle.bfloat16
    encoder_max_partition_size: int = 32768

@@ -82,8 +78,6 @@ class MLAAttentionMetadata(AttentionMetadata):
    block_tables: Optional[paddle.Tensor] = None
    rotary_embs: Optional[paddle.Tensor] = None
    attn_mask: Optional[paddle.Tensor] = None
    encoder_block_shape_q: int = -1
    decoder_block_shape_q: int = -1
    _fuse_kernel_compute_dtype: str = "bf16"

    # pd_disaggregation

@@ -105,6 +99,8 @@ class MLAAttentionBackend(AttentionBackend):
        kv_num_heads: int,
        num_heads: int,
        head_dim: int,
        encoder_block_shape_q: int = -1,
        decoder_block_shape_q: int = -1,
    ) -> None:
        """
        MLAAttentionBackend __init__

@@ -128,8 +124,11 @@ class MLAAttentionBackend(AttentionBackend):

        self.kv_num_heads: int = kv_num_heads
        self.num_heads: int = num_heads
        self.group_size: int = self.num_heads // self.kv_num_heads
        self.head_dim: int = fd_config.model_config.head_dim
        self.num_layers: int = fd_config.model_config.num_hidden_layers
        self.encoder_block_shape_q: int = encoder_block_shape_q
        self.decoder_block_shape_q: int = decoder_block_shape_q

        # For Multi Head Latent Attention
        self.kv_lora_rank: int = fd_config.model_config.kv_lora_rank

@@ -152,8 +151,6 @@ class MLAAttentionBackend(AttentionBackend):
    def init_attention_metadata(self, forward_meta: ForwardMeta):
        """Initialize attention metadata so that all layers in the forward pass can reuse it."""
        metadata = MLAAttentionMetadata()
        metadata.encoder_block_shape_q = 64
        metadata.decoder_block_shape_q = 16
        metadata.max_partition_size = 32768
        metadata.encoder_max_partition_size = self.max_seq_len
        metadata._dtype = paddle.get_default_dtype()

@@ -176,27 +173,25 @@ class MLAAttentionBackend(AttentionBackend):
            metadata.kv_batch_ids,
            metadata.kv_tile_ids_per_batch,
            metadata.kv_num_blocks,
            metadata.decoder_batch_ids,
            metadata.decoder_tile_ids_per_batch,
            metadata.decoder_num_blocks,
            metadata.max_len_kv,
            metadata.set_max_lengths,
        ) = get_block_shape_and_split_kv_block(
            forward_meta.seq_lens_encoder,
            forward_meta.seq_lens_decoder,
            forward_meta.seq_lens_this_time,
            metadata.encoder_block_shape_q,
            metadata.decoder_block_shape_q,
            self.num_heads // self.kv_num_heads,
            forward_meta.decoder_batch_ids,
            forward_meta.decoder_tile_ids_per_batch,
            forward_meta.decoder_num_blocks_cpu,
            forward_meta.max_len_tensor_cpu,
            self.encoder_block_shape_q,
            self.decoder_block_shape_q,
            self.group_size,
            self.block_size,
            self.speculate_max_draft_token_num + 1,
        )

        # MLA
        metadata.max_enc_len_this_time = metadata.set_max_lengths[1]
        metadata.max_dec_len_this_time = metadata.set_max_lengths[2]
        forward_meta.max_enc_len_this_time = metadata.set_max_lengths[1]
        forward_meta.max_dec_len_this_time = metadata.set_max_lengths[2]
        metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
        metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]

        # pd_disaggregation
        metadata.kv_signal_data_list = [None] * self.num_layers

@@ -216,9 +211,6 @@ class MLAAttentionBackend(AttentionBackend):

        self.attention_metadata: AttentionMetadata = metadata

        forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
        forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False)

    def get_attntion_meta(self) -> AttentionMetadata:
        """get_attntion_meta"""
        return self.attention_metadata

@@ -354,8 +346,8 @@ class MLAAttentionBackend(AttentionBackend):
            metadata.kv_num_blocks,
            forward_meta.decoder_batch_ids,
            forward_meta.decoder_tile_ids_per_batch,
            metadata.decoder_num_blocks,
            metadata.decoder_num_blocks,  # PaddleNLP passes decoder_num_blocks_cpu here
            forward_meta.decoder_num_blocks_cpu,
            forward_meta.decoder_num_blocks_cpu,
            metadata.max_enc_len_this_time,
            metadata.max_dec_len_this_time,
            metadata.max_len_kv,

@@ -476,8 +468,8 @@ class MLAAttentionBackend(AttentionBackend):
            metadata.kv_num_blocks,
            forward_meta.decoder_batch_ids,
            forward_meta.decoder_tile_ids_per_batch,
            metadata.decoder_num_blocks,
            metadata.decoder_num_blocks,  # PaddleNLP passes decoder_num_blocks_cpu here
            forward_meta.decoder_num_blocks_cpu,
            forward_meta.decoder_num_blocks_cpu,
            metadata.max_enc_len_this_time,
            metadata.max_dec_len_this_time,
            metadata.max_len_kv,

@@ -28,6 +28,10 @@ def get_block_shape_and_split_kv_block(
    seq_lens_encoder: paddle.Tensor,
    seq_lens_decoder: paddle.Tensor,
    seq_lens_this_time: paddle.Tensor,
    decoder_batch_ids: paddle.Tensor,
    decoder_tile_ids_per_batch: paddle.Tensor,
    decoder_num_blocks_x_cpu: paddle.Tensor,
    max_len_tensor_cpu: paddle.Tensor,
    encoder_block_shape_q: int,
    decoder_block_shape_q: int,
    group_size: int,

@@ -45,15 +49,15 @@ def get_block_shape_and_split_kv_block(
            kv_batch_ids,
            kv_tile_ids_per_batch,
            kv_num_blocks,
            decoder_batch_ids,
            decoder_tile_ids_per_batch,
            decoder_num_blocks,
            max_len_kv,
            set_max_lengths,
            max_len_kv_cpu,
        ) = get_block_shape_and_split_kv_block_cuda(
            seq_lens_encoder,
            seq_lens_decoder,
            seq_lens_this_time,
            decoder_batch_ids,
            decoder_tile_ids_per_batch,
            decoder_num_blocks_x_cpu,
            max_len_tensor_cpu,
            encoder_block_shape_q,
            decoder_block_shape_q,
            group_size,

@@ -67,11 +71,7 @@ def get_block_shape_and_split_kv_block(
            kv_batch_ids,
            kv_tile_ids_per_batch,
            kv_num_blocks,
            decoder_batch_ids,
            decoder_tile_ids_per_batch,
            decoder_num_blocks,
            max_len_kv,
            set_max_lengths,
            max_len_kv_cpu,
        )
    else:
        raise NotImplementedError

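A condensed sketch of the new calling convention from a backend's point of view: the decode-side buffers are passed in and filled in place, so only the encoder/kv tile plan and max_len_kv come back. The function name plan_attention_step and the seven-element unpack are illustrative assumptions drawn from the call sites and the wrapper hunks above, not a verbatim API.

def plan_attention_step(forward_meta, backend, get_block_shape_and_split_kv_block):
    """Illustrative only: shows the in-place contract, not a real FastDeploy method."""
    (
        encoder_batch_ids,
        encoder_tile_ids_per_batch,
        encoder_num_blocks,
        kv_batch_ids,
        kv_tile_ids_per_batch,
        kv_num_blocks,
        max_len_kv,
    ) = get_block_shape_and_split_kv_block(
        forward_meta.seq_lens_encoder,
        forward_meta.seq_lens_decoder,
        forward_meta.seq_lens_this_time,
        forward_meta.decoder_batch_ids,           # in-place output
        forward_meta.decoder_tile_ids_per_batch,  # in-place output
        forward_meta.decoder_num_blocks_cpu,      # in-place output, pinned host memory
        forward_meta.max_len_tensor_cpu,          # in-place output, CPU
        backend.encoder_block_shape_q,
        backend.decoder_block_shape_q,
        backend.group_size,
        backend.block_size,
        backend.speculate_max_draft_token_num + 1,
    )
    return encoder_batch_ids, kv_batch_ids, max_len_kv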
@@ -44,26 +44,13 @@ class XPUAttentionMetadata(AttentionMetadata):
    XPUAttentionMetadata
    """

    max_len_kv: paddle.Tensor = None
    set_max_lengths: int = -1
    encoder_batch_ids: paddle.Tensor = None
    encoder_tile_ids_per_batch: paddle.Tensor = None
    encoder_num_blocks: paddle.Tensor = None
    kv_batch_ids: paddle.Tensor = None
    kv_tile_ids_per_batch: paddle.Tensor = None
    kv_num_blocks: paddle.Tensor = None
    decoder_batch_ids: paddle.Tensor = None
    decoder_tile_ids_per_batch: paddle.Tensor = None
    decoder_num_blocks: paddle.Tensor = None

    _dtype: paddle.dtype = paddle.bfloat16
    encoder_max_partition_size: int = 32768
    max_partition_size: int = 32768
    block_tables: Optional[paddle.Tensor] = None
    rotary_embs: Optional[paddle.Tensor] = None
    attn_mask: Optional[paddle.Tensor] = None
    encoder_block_shape_q: int = -1
    decoder_block_shape_q: int = -1

    _fuse_kernel_compute_dtype: str = "bf16"

    # pd_disaggregation

@@ -91,7 +78,6 @@ class XPUAttentionBackend(AttentionBackend):
        """
        super().__init__()
        self.attention_metadata: XPUAttentionMetadata = None
        # TODO(gongshaotian): Use fd_config parameters in the correct location
        self.block_size: int = fd_config.cache_config.block_size
        self.max_seq_len: int = fd_config.parallel_config.max_model_len
        self.rope_theta: float = (

@@ -99,9 +85,6 @@ class XPUAttentionBackend(AttentionBackend):
        )
        self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
        self.causal: bool = getattr(fd_config.model_config, "causal", True)
        # self.speculate_method = fd_config.parallel_config.speculate_method
        # self.use_speculate = self.speculate_method is not None
        # self.speculate_max_draft_token_num = fd_config.parallel_config.speculate_max_draft_tokens
        self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
        self.rank: int = fd_config.parallel_config.tensor_parallel_rank

@@ -117,8 +100,6 @@ class XPUAttentionBackend(AttentionBackend):
    def init_attention_metadata(self, forward_meta: ForwardMeta):
        """Initialize attention metadata so that all layers in the forward pass can reuse it."""
        metadata = XPUAttentionMetadata()
        metadata.encoder_block_shape_q = 64
        metadata.decoder_block_shape_q = 16
        metadata.max_partition_size = 32768
        metadata.encoder_max_partition_size = 32768
        metadata._dtype = paddle.get_default_dtype()

@@ -184,13 +184,26 @@ class MTPProposer(Proposer):
        """
        assert len(self.attn_backends) == 0

        # TODO(gongshaotian): Get rank from config
        num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_size
        self.model_config.kv_num_heads = (
            int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size
        self.model_config.kv_num_heads = max(
            1,
            int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size,
        )
        head_dim = self.model_config.head_dim

        # Initialize AttentionBackend buffers
        encoder_block_shape_q = 64
        decoder_block_shape_q = 16

        self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.main_model_inputs["decoder_batch_ids"])
        self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like(
            self.main_model_inputs["decoder_tile_ids_per_batch"]
        )
        self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
            self.main_model_inputs["decoder_num_blocks_cpu"]
        ).pin_memory()
        self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like(self.main_model_inputs["max_len_tensor_cpu"]).cpu()

        # Get the attention backend
        attn_cls = get_attention_backend()
        attn_backend = attn_cls(

@@ -198,6 +211,8 @@ class MTPProposer(Proposer):
            kv_num_heads=self.model_config.kv_num_heads,
            num_heads=num_heads,
            head_dim=head_dim,
            encoder_block_shape_q=encoder_block_shape_q,
            decoder_block_shape_q=decoder_block_shape_q,
        )
        if attn_backend is None:
            raise NotImplementedError(

@@ -293,6 +308,12 @@ class MTPProposer(Proposer):
        self.model_inputs["base_model_draft_tokens"] = self.main_model_inputs["draft_tokens"]
        self.model_inputs["substep"] = 0

        # Declare AttentionBackend buffers
        self.model_inputs["decoder_batch_ids"] = None
        self.model_inputs["decoder_tile_ids_per_batch"] = None
        self.model_inputs["decoder_num_blocks_cpu"] = None  # Pinned Memory
        self.model_inputs["max_len_tensor_cpu"] = None  # CPU

        # Input tokens
        self.model_inputs["draft_tokens"] = paddle.full(shape=[self.max_num_seqs, 2], fill_value=-1, dtype="int64")

@@ -405,6 +426,8 @@ class MTPProposer(Proposer):
            attn_backend=self.attn_backends[0],
            decoder_batch_ids=self.model_inputs["decoder_batch_ids"],
            decoder_tile_ids_per_batch=self.model_inputs["decoder_tile_ids_per_batch"],
            decoder_num_blocks_cpu=self.model_inputs["decoder_num_blocks_cpu"],
            max_len_tensor_cpu=self.model_inputs["max_len_tensor_cpu"],
            seq_lens_encoder=self.model_inputs["seq_lens_encoder"],
            seq_lens_decoder=self.model_inputs["seq_lens_decoder"],
            seq_lens_this_time=self.model_inputs["seq_lens_this_time"],
@@ -417,9 +417,11 @@ class GCUModelRunner(ModelRunnerBase):
        self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        # AttentionBackend buffers
        self.share_inputs["decoder_batch_ids"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        # Declare AttentionBackend buffers
        self.share_inputs["decoder_batch_ids"] = None
        self.share_inputs["decoder_tile_ids_per_batch"] = None
        self.share_inputs["decoder_num_blocks_cpu"] = None  # Pinned Memory
        self.share_inputs["max_len_tensor_cpu"] = None  # CPU

        # Initialize rotary position embedding
        tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))

@@ -579,6 +581,8 @@ class GCUModelRunner(ModelRunnerBase):
            attn_backend=self.attn_backends[0],
            decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
            decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
            decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"],
            max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"],
            seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
            seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
            seq_lens_this_time=self.share_inputs["seq_lens_this_time"],

@@ -655,6 +659,18 @@ class GCUModelRunner(ModelRunnerBase):
        )
        head_dim = self.model_config.head_dim

        # Initialize AttentionBackend buffers
        encoder_block_shape_q = 64
        decoder_block_shape_q = 16
        decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
        decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
            (decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q
        )
        self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
        self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
        self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
        self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()

        # Get the attention backend
        attn_cls = get_attention_backend()
        attn_backend = attn_cls(

@@ -662,6 +678,8 @@ class GCUModelRunner(ModelRunnerBase):
            kv_num_heads=self.model_config.kv_num_heads,
            num_heads=num_heads,
            head_dim=head_dim,
            encoder_block_shape_q=encoder_block_shape_q,
            decoder_block_shape_q=decoder_block_shape_q,
        )
        if attn_backend is None:
            raise NotImplementedError(

@@ -1179,9 +1197,5 @@ class GCUModelRunner(ModelRunnerBase):
        Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
        In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch.
        """
        # TODO(gongshaotian): Use more efficient implementation
        if self.forward_meta.step_use_cudagraph:
            num_empty_batch = (self.forward_meta.seq_lens_this_time == 0).sum()
            for i in range(1, num_empty_batch + 1):
                self.forward_meta.decoder_batch_ids[-i] = 0
                self.forward_meta.decoder_tile_ids_per_batch[-i] = 0
        # In init_attention_metadata, the decode buffer has already been cleared
        return

@@ -610,9 +610,7 @@ class GPUModelRunner(ModelRunnerBase):
        self.share_inputs["step_seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        self.share_inputs["prompt_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
        self.share_inputs["step_idx"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
        self.share_inputs["not_need_stop"] = paddle.full(
            [1], False, dtype="bool"
        ).cpu()  # TODO(gongshaotian): move to pinned memory
        self.share_inputs["not_need_stop"] = paddle.full([1], False, dtype="bool").cpu()
        self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
        self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64")

@@ -643,9 +641,12 @@ class GPUModelRunner(ModelRunnerBase):
        self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        # AttentionBackend buffers
        self.share_inputs["decoder_batch_ids"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")

        # Declare AttentionBackend buffers
        self.share_inputs["decoder_batch_ids"] = None
        self.share_inputs["decoder_tile_ids_per_batch"] = None
        self.share_inputs["decoder_num_blocks_cpu"] = None  # Pinned Memory
        self.share_inputs["max_len_tensor_cpu"] = None  # CPU

        # Initialize rotary position embedding
        tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))

@@ -845,6 +846,8 @@ class GPUModelRunner(ModelRunnerBase):
            attn_backend=self.attn_backends[0],
            decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
            decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
            decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"],
            max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"],
            seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
            seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
            seq_lens_this_time=self.share_inputs["seq_lens_this_time"],

@@ -856,7 +859,6 @@ class GPUModelRunner(ModelRunnerBase):
        )

        # Update Batch type for cuda graph
        # TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch
        only_decode_batch = True
        prefill_exists = None
        # mix ep in single node

@@ -946,6 +948,18 @@ class GPUModelRunner(ModelRunnerBase):
        )
        head_dim = self.model_config.head_dim

        # Initialize AttentionBackend buffers
        encoder_block_shape_q = 64
        decoder_block_shape_q = 16
        decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
        decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
            (decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q
        )
        self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
        self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
        self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
        self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()

        # Get the attention backend
        attn_cls = get_attention_backend()
        attn_backend = attn_cls(
@@ -953,6 +967,8 @@ class GPUModelRunner(ModelRunnerBase):
            kv_num_heads=self.model_config.kv_num_heads,
            num_heads=num_heads,
            head_dim=head_dim,
            encoder_block_shape_q=encoder_block_shape_q,
            decoder_block_shape_q=decoder_block_shape_q,
        )

        self.attn_backends.append(attn_backend)

@@ -1527,12 +1543,8 @@ class GPUModelRunner(ModelRunnerBase):
        Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
        In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch.
        """
        # TODO(gongshaotian): Use more efficient implementation
        if self.forward_meta.step_use_cudagraph:
            num_empty_batch = (self.forward_meta.seq_lens_this_time == 0).sum()
            for i in range(1, num_empty_batch + 1):
                self.forward_meta.decoder_batch_ids[-i] = 0
                self.forward_meta.decoder_tile_ids_per_batch[-i] = 0
        # In init_attention_metadata, the decode buffer has already been cleared
        return

    def _init_image_preprocess(self) -> None:
        processor = DataProcessor(
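A hedged sketch of why the per-replay zeroing loop could be dropped in the model runners above: the refactored kernel now zero-fills the shared decode buffers with cudaMemsetAsync at the start of every step inside init_attention_metadata, so the padded tail entries of decoder_batch_ids and decoder_tile_ids_per_batch are already clean when the CUDA graph replays. The simplified cleanup then reduces to a no-op along these lines (the function name below is a hypothetical stand-in, not the repo's method name):

def clean_cudagraph_padding_buffers(forward_meta) -> None:
    # After this PR, decoder_batch_ids, decoder_tile_ids_per_batch and
    # decoder_num_blocks_cpu are cleared by the GetBlockShapeAndSplitKVBlock
    # kernel during init_attention_metadata, so no per-replay cleanup of the
    # padded batch entries is needed here.
    return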
@@ -383,8 +383,8 @@ class IluvatarModelRunner(ModelRunnerBase):
        self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        # AttentionBackend buffers
        self.share_inputs["decoder_batch_ids"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
        self.share_inputs["decoder_batch_ids"] = None
        self.share_inputs["decoder_tile_ids_per_batch"] = None

        # Initialize rotary position embedding
        tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))