Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Executor] Refactor GetBlockShapeAndSplitKVBlock Kernel (#2989)
* reset decoder_block_shape_q buffer
* refactor GetBlockShapeAndSplitKVBlock kernel and cudagraph padding batch
* update decode_max_tile_size
* fix pre-commit
* update block_multihead_attn_backend
* update flash attn backend
* update MLA Attention
* update XPU Attention
* update gcu, iluvatar model runner
* update MTP
* fix MTP bug
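At a high level, the decoder split buffers (decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks_cpu, max_len_tensor_cpu) now live on ForwardMeta and are passed into the op as preallocated in-place arguments; only the encoder/KV split results and max_len_kv come back as return values. Below is a condensed sketch of the new per-step call, assembled from the attention-backend hunks later in this diff; it is not a verbatim quote, and the self.* values come from each backend's constructor:

    (
        metadata.encoder_batch_ids,
        metadata.encoder_tile_ids_per_batch,
        metadata.encoder_num_blocks,
        metadata.kv_batch_ids,
        metadata.kv_tile_ids_per_batch,
        metadata.kv_num_blocks,
        metadata.max_len_kv,
    ) = get_block_shape_and_split_kv_block(
        forward_meta.seq_lens_encoder,
        forward_meta.seq_lens_decoder,
        forward_meta.seq_lens_this_time,
        forward_meta.decoder_batch_ids,           # filled in place
        forward_meta.decoder_tile_ids_per_batch,  # filled in place
        forward_meta.decoder_num_blocks_cpu,      # filled in place, pinned memory
        forward_meta.max_len_tensor_cpu,          # filled in place, CPU
        self.encoder_block_shape_q,
        self.decoder_block_shape_q,
        self.group_size,
        self.block_size,
        self.speculate_max_draft_token_num + 1,   # decoder_step_token_num
    )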
@@ -195,22 +195,25 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
     const paddle::Tensor &seq_lens_encoder,
     const paddle::Tensor &seq_lens_decoder,
     const paddle::Tensor &seq_lens_this_time,
-    const int encoder_block_shape_q, const int decoder_block_shape_q,
-    const int group_size, const int block_size,
-    const int decoder_step_token_num) {
+    paddle::Tensor &decoder_batch_ids,          // Inplace
+    paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
+    paddle::Tensor &decoder_num_blocks_x_cpu,   // Inplace, Pinned Memory
+    paddle::Tensor &max_len_tensor_cpu,         // Inplace, Pinned Memory
+    const int encoder_block_shape_q,
+    const int decoder_block_shape_q,
+    const int group_size,
+    const int block_size,
+    const int decoder_step_token_num)
+{
   auto stream = seq_lens_encoder.stream();
   int bsz = seq_lens_this_time.shape()[0];
-  auto max_len_tensor =
-      GetEmptyTensor({8}, paddle::DataType::INT32, seq_lens_encoder.place());
-  GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
-            max_len_tensor, bsz);
-
-  // max_len_this_time, max_enc_len_this_time, max_dec_len_this_time,
-  // max_enc_dec_len_this_time, max_just_dec_len_this_time,
-  // max_just_dec_merged_len_this_time, max_system_len,
-  // max_just_dec_len_without_system
-  auto max_len_cpu = max_len_tensor.copy_to(paddle::CPUPlace(), false);
-  auto max_len_cpu_ptr = max_len_cpu.data<int>();
+  paddle::Tensor max_len_tensor_gpu = GetEmptyTensor({max_len_tensor_cpu.shape()[0]}, paddle::DataType::INT32, seq_lens_this_time.place());
+  GetMaxLen(seq_lens_decoder, seq_lens_this_time, seq_lens_encoder,
+            max_len_tensor_gpu, bsz);
+  max_len_tensor_cpu.copy_(max_len_tensor_gpu, max_len_tensor_cpu.place(), false);
+  auto max_len_cpu_ptr = max_len_tensor_cpu.data<int>();
   int max_len_this_time = max_len_cpu_ptr[0];
   int max_enc_len_this_time = max_len_cpu_ptr[1];
   int max_dec_len_this_time = max_len_cpu_ptr[2];
@@ -222,14 +225,11 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(

   paddle::Tensor encoder_batch_ids;
   paddle::Tensor encoder_tile_ids_per_batch;
   paddle::Tensor encoder_num_blocks_x_cpu; /*cpu*/
   paddle::Tensor kv_batch_ids;
   paddle::Tensor kv_tile_ids_per_batch;
   paddle::Tensor kv_num_blocks_x_cpu; /*cpu*/
-  paddle::Tensor decoder_batch_ids;
-  paddle::Tensor decoder_tile_ids_per_batch;
-  paddle::Tensor decoder_num_blocks_x_cpu; /*cpu*/
   paddle::Tensor max_len_kv_cpu; /*cpu*/

   auto max_len_kv =
       GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_decoder.place());
@@ -291,92 +291,64 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
     kv_num_blocks_x_cpu =
         GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
   }
   if (max_just_dec_len_this_time > 0) {
-    const uint32_t decoder_max_tile_size_per_bs_q =
-        div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
-
-    decoder_batch_ids =
-        GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
-                       paddle::DataType::INT32, seq_lens_encoder.place());
-    decoder_tile_ids_per_batch =
-        GetEmptyTensor({bsz * decoder_max_tile_size_per_bs_q},
-                       paddle::DataType::INT32, seq_lens_encoder.place());
+    // Clear buffer
+    const uint32_t decoder_max_tile_size_per_bs_q = div_up((decoder_step_token_num * group_size), decoder_block_shape_q);
+    const uint32_t decoder_batch_shape = bsz * decoder_max_tile_size_per_bs_q;
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_batch_ids.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_tile_ids_per_batch.data<int>(), 0, decoder_batch_shape * sizeof(int32_t), stream));
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemsetAsync(decoder_num_blocks_x_cpu.data<int>(), 0, sizeof(int32_t), stream));

     auto decoder_num_blocks_x =
         GetEmptyTensor({1}, paddle::DataType::INT32, seq_lens_encoder.place());
     split_q_block<<<1, 32, 0, stream>>>(
-        seq_lens_this_time.data<int>(), seq_lens_encoder.data<int>(),
-        decoder_batch_ids.data<int>(), decoder_tile_ids_per_batch.data<int>(),
-        decoder_num_blocks_x.data<int>(), bsz, decoder_block_shape_q,
+        seq_lens_this_time.data<int>(),
+        seq_lens_encoder.data<int>(),
+        decoder_batch_ids.data<int>(),
+        decoder_tile_ids_per_batch.data<int>(),
+        decoder_num_blocks_x.data<int>(),
+        bsz,
+        decoder_block_shape_q,
         group_size);
-    decoder_num_blocks_x_cpu =
-        decoder_num_blocks_x.copy_to(paddle::CPUPlace(), false);
-  } else {
-    decoder_batch_ids =
-        GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
-    decoder_tile_ids_per_batch =
-        GetEmptyTensor({0}, paddle::DataType::INT32, seq_lens_encoder.place());
-    decoder_num_blocks_x_cpu =
-        GetEmptyTensor({0}, paddle::DataType::INT32, paddle::CPUPlace());
+    decoder_num_blocks_x_cpu.copy_(decoder_num_blocks_x, decoder_num_blocks_x_cpu.place(), false);
   }

-  return {encoder_batch_ids,
-          encoder_tile_ids_per_batch,
-          encoder_num_blocks_x_cpu, /*cpu*/
-          kv_batch_ids,
-          kv_tile_ids_per_batch,
-          kv_num_blocks_x_cpu, /*cpu*/
-          decoder_batch_ids,
-          decoder_tile_ids_per_batch,
-          decoder_num_blocks_x_cpu, /*cpu*/
-          max_len_kv_cpu /*cpu*/,
-          max_len_cpu};
-}
-
-std::vector<paddle::DataType> GetBlockShapeAndSplitKVBlockInferDtype(
-    const paddle::DataType &seq_lens_encoder_dtype,
-    const paddle::DataType &seq_lens_decoder_dtype,
-    const paddle::DataType &seq_lens_this_time_dtype) {
-  return {
-      paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
-      paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
-      paddle::DataType::INT32, paddle::DataType::INT32, paddle::DataType::INT32,
-      paddle::DataType::INT32, paddle::DataType::INT32};
-}
-
-std::vector<std::vector<int64_t>> GetBlockShapeAndSplitKVBlockInferShape(
-    const std::vector<int64_t> &seq_lens_encoder_shape,
-    const std::vector<int64_t> &seq_lens_decoder_shape,
-    const std::vector<int64_t> &seq_lens_this_time_shape) {
-  std::vector<int64_t> dynamic_shape = {-1};
-
-  return {dynamic_shape,
-          dynamic_shape,
-          {1},
-          dynamic_shape,
-          dynamic_shape,
-          {1},
-          dynamic_shape,
-          dynamic_shape,
-          {1},
-          {1},
-          {8}};
+  return {
+      encoder_batch_ids,
+      encoder_tile_ids_per_batch,
+      encoder_num_blocks_x_cpu, /*cpu*/
+      kv_batch_ids,
+      kv_tile_ids_per_batch,
+      kv_num_blocks_x_cpu, /*cpu*/
+      max_len_kv_cpu, /*cpu*/
+  };
 }

 PD_BUILD_STATIC_OP(get_block_shape_and_split_kv_block)
-    .Inputs({"seq_lens_encoder", "seq_lens_decoder", "seq_lens_this_time"})
-    .Outputs({paddle::Optional("encoder_batch_ids"),
-              paddle::Optional("encoder_tile_ids_per_batch"),
-              paddle::Optional("encoder_num_blocks"),
-              paddle::Optional("kv_batch_ids"),
-              paddle::Optional("kv_tile_ids_per_batch"),
-              paddle::Optional("kv_num_blocks"),
-              paddle::Optional("decoder_batch_ids"),
-              paddle::Optional("decoder_tile_ids_per_batch"),
-              paddle::Optional("decoder_num_blocks"),
-              paddle::Optional("max_len_kv"), "set_max_lengths"})
-    .Attrs({"encoder_block_shape_q: int", "decoder_block_shape_q: int",
-            "group_size: int", "block_size: int",
-            "decoder_step_token_num: int"})
-    .SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock))
-    .SetInferShapeFn(PD_INFER_SHAPE(GetBlockShapeAndSplitKVBlockInferShape))
-    .SetInferDtypeFn(PD_INFER_DTYPE(GetBlockShapeAndSplitKVBlockInferDtype));
+    .Inputs({
+        "seq_lens_encoder",
+        "seq_lens_decoder",
+        "seq_lens_this_time",
+        "decoder_batch_ids",
+        "decoder_tile_ids_per_batch",
+        "decoder_num_blocks_x_cpu",
+        "max_len_tensor_cpu"
+    })
+    .Outputs({
+        paddle::Optional("encoder_batch_ids"),
+        paddle::Optional("encoder_tile_ids_per_batch"),
+        paddle::Optional("encoder_num_blocks_x_cpu"),
+        paddle::Optional("kv_batch_ids"),
+        paddle::Optional("kv_tile_ids_per_batch"),
+        paddle::Optional("kv_num_blocks_x_cpu"),
+        "max_len_kv_cpu"
+    })
+    .Attrs({
+        "encoder_block_shape_q: int",
+        "decoder_block_shape_q: int",
+        "group_size: int",
+        "block_size: int",
+        "decoder_step_token_num: int"
+    })
+    .SetKernelFn(PD_KERNEL(GetBlockShapeAndSplitKVBlock));
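The cudaMemsetAsync calls above assume the caller sized the in-place decoder buffers for the worst case: bsz * div_up(decoder_step_token_num * group_size, decoder_block_shape_q) int32 elements each. A small sketch of that sizing rule (Python for illustration; div_up mirrors the helper used in the CUDA source):

    def div_up(a: int, b: int) -> int:
        # Ceiling division, as used by the kernel's div_up helper.
        return (a + b - 1) // b

    def decoder_buffer_len(bsz: int, decoder_step_token_num: int,
                           group_size: int, decoder_block_shape_q: int) -> int:
        # Minimum length of decoder_batch_ids / decoder_tile_ids_per_batch
        # that the kernel clears and then fills in place.
        return bsz * div_up(decoder_step_token_num * group_size, decoder_block_shape_q)

    # Illustrative numbers (not from the diff): 128 sequences, 1 decode token per
    # step, group size 8, decoder tile of 16 queries -> 128 elements per buffer.
    assert decoder_buffer_len(128, 1, 8, 16) == 128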
@@ -235,8 +235,14 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
     const paddle::Tensor &seq_lens_encoder,
     const paddle::Tensor &seq_lens_decoder,
     const paddle::Tensor &seq_lens_this_time,
-    const int encoder_block_shape_q, const int decoder_block_shape_q,
-    const int group_size, const int block_size,
+    paddle::Tensor &decoder_batch_ids,          // Inplace
+    paddle::Tensor &decoder_tile_ids_per_batch, // Inplace
+    paddle::Tensor &decoder_num_blocks_x_cpu,   // Inplace, Pinned Memory
+    paddle::Tensor &max_len_tensor_cpu,         // Inplace, Pinned Memory
+    const int encoder_block_shape_q,
+    const int decoder_block_shape_q,
+    const int group_size,
+    const int block_size,
     const int decoder_step_token_num);

 std::vector<paddle::Tensor> GetPaddingOffset(const paddle::Tensor &input_ids,
@@ -77,6 +77,10 @@ class ForwardMeta:
     decoder_batch_ids: Optional[paddle.Tensor] = None
     # Tile ID for each batch of the decoder. Used by attention backend.
     decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
+    # The number of blocks that attention backend can use in decode stage
+    decoder_num_blocks_cpu: Optional[paddle.Tensor] = None
+    # Recorded multiple lengths related to prefill or decode
+    max_len_tensor_cpu: Optional[paddle.Tensor] = None

     # Sequence length of encoder for ever batch
     seq_lens_encoder: Optional[paddle.Tensor] = None
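For reference, the two new fields are small host-side tensors: decoder_num_blocks_cpu is a single pinned int32, and max_len_tensor_cpu holds eight max-length statistics. The slot meanings below are inferred from the comment block removed from the CUDA source earlier in this diff, so treat them as a hedged summary rather than an authoritative API:

    # Inferred layout of max_len_tensor_cpu (int32[8]); names follow the removed kernel comment.
    MAX_LEN_TENSOR_CPU_SLOTS = {
        0: "max_len_this_time",
        1: "max_enc_len_this_time",
        2: "max_dec_len_this_time",
        3: "max_enc_dec_len_this_time",
        4: "max_just_dec_len_this_time",
        5: "max_just_dec_merged_len_this_time",
        6: "max_system_len",
        7: "max_just_dec_len_without_system",
    }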
@@ -48,17 +48,13 @@ class AppendAttentionMetadata(AttentionMetadata):
     AppendAttentionMetadata
     """

-    max_len_kv: paddle.Tensor = None
-    set_max_lengths: int = -1
     encoder_batch_ids: paddle.Tensor = None
     encoder_tile_ids_per_batch: paddle.Tensor = None
     encoder_num_blocks: paddle.Tensor = None
     kv_batch_ids: paddle.Tensor = None
     kv_tile_ids_per_batch: paddle.Tensor = None
     kv_num_blocks: paddle.Tensor = None
-    decoder_batch_ids: paddle.Tensor = None
-    decoder_tile_ids_per_batch: paddle.Tensor = None
-    decoder_num_blocks: paddle.Tensor = None
+    max_len_kv: paddle.Tensor = None

     _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
@@ -66,8 +62,6 @@ class AppendAttentionMetadata(AttentionMetadata):
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    encoder_block_shape_q: int = -1
-    decoder_block_shape_q: int = -1
     _fuse_kernel_compute_dtype: str = "bf16"

     # pd_disaggregation
@@ -89,6 +83,8 @@ class AppendAttentionBackend(AttentionBackend):
         kv_num_heads: int,
         num_heads: int,
         head_dim: int,
+        encoder_block_shape_q: int = -1,
+        decoder_block_shape_q: int = -1,
     ) -> None:
         """
         AppendAttentionBackend __init__
@@ -110,9 +106,12 @@ class AppendAttentionBackend(AttentionBackend):

         self.kv_num_heads: int = kv_num_heads
         self.num_heads: int = num_heads
+        self.group_size: int = self.num_heads // self.kv_num_heads
         self.head_dim: int = fd_config.model_config.head_dim
         self.num_layers: int = fd_config.model_config.num_hidden_layers
         self.max_partition_size: int = int(os.getenv("FLAGS_max_partition_size", 32768))
+        self.encoder_block_shape_q: int = encoder_block_shape_q
+        self.decoder_block_shape_q: int = decoder_block_shape_q

         self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode

@@ -126,8 +125,6 @@ class AppendAttentionBackend(AttentionBackend):
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attntion metadata hence all layers in the forward pass can reuse it."""
         metadata = AppendAttentionMetadata()
-        metadata.encoder_block_shape_q = 64
-        metadata.decoder_block_shape_q = 16
         metadata.max_partition_size = self.max_partition_size
         metadata.encoder_max_partition_size = self.max_seq_len
         metadata._dtype = paddle.get_default_dtype()
@@ -148,18 +145,18 @@ class AppendAttentionBackend(AttentionBackend):
             metadata.kv_batch_ids,
             metadata.kv_tile_ids_per_batch,
             metadata.kv_num_blocks,
-            metadata.decoder_batch_ids,  # will copy to buffer
-            metadata.decoder_tile_ids_per_batch,  # will copy to buffer
-            metadata.decoder_num_blocks,
             metadata.max_len_kv,
-            metadata.set_max_lengths,
         ) = get_block_shape_and_split_kv_block(
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
-            metadata.encoder_block_shape_q,
-            metadata.decoder_block_shape_q,
-            self.num_heads // self.kv_num_heads,
+            forward_meta.decoder_batch_ids,
+            forward_meta.decoder_tile_ids_per_batch,
+            forward_meta.decoder_num_blocks_cpu,
+            forward_meta.max_len_tensor_cpu,
+            self.encoder_block_shape_q,
+            self.decoder_block_shape_q,
+            self.group_size,
             self.block_size,
             self.speculate_max_draft_token_num + 1,
         )
@@ -181,8 +178,6 @@ class AppendAttentionBackend(AttentionBackend):
         )

         self.attention_metadata: AttentionMetadata = metadata
-        forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
-        forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False)

     def get_attntion_meta(self) -> AttentionMetadata:
         """get_attntion_meta"""
@@ -249,10 +244,10 @@ class AppendAttentionBackend(AttentionBackend):
             metadata.kv_batch_ids,
             metadata.kv_tile_ids_per_batch,
             metadata.kv_num_blocks,
-            forward_meta.decoder_batch_ids,  # from buffer
-            forward_meta.decoder_tile_ids_per_batch,  # from buffer
-            metadata.decoder_num_blocks,
-            metadata.set_max_lengths,
+            forward_meta.decoder_batch_ids,
+            forward_meta.decoder_tile_ids_per_batch,
+            forward_meta.decoder_num_blocks_cpu,
+            forward_meta.max_len_tensor_cpu,
             metadata.max_len_kv,
             metadata.rotary_embs,
             metadata.attn_mask,
@@ -275,8 +270,8 @@ class AppendAttentionBackend(AttentionBackend):
             getattr(layer, "quant_max_bound", 0.0),
             getattr(layer, "quant_min_bound", 0.0),
             getattr(layer, "out_scale", -1.0),
-            metadata.encoder_block_shape_q,
-            metadata.decoder_block_shape_q,
+            self.encoder_block_shape_q,
+            self.decoder_block_shape_q,
             metadata.max_partition_size,
             metadata.encoder_max_partition_size,
             self.speculate_max_draft_token_num + 1,
@@ -38,17 +38,12 @@ class BlockAttentionMetadata(AttentionMetadata):
     BlockAttentionMetadata
     """

-    max_len_kv: paddle.Tensor = None
-    set_max_lengths: int = -1
     encoder_batch_ids: paddle.Tensor = None
     encoder_tile_ids_per_batch: paddle.Tensor = None
     encoder_num_blocks: paddle.Tensor = None
     kv_batch_ids: paddle.Tensor = None
     kv_tile_ids_per_batch: paddle.Tensor = None
     kv_num_blocks: paddle.Tensor = None
-    decoder_batch_ids: paddle.Tensor = None
-    decoder_tile_ids_per_batch: paddle.Tensor = None
-    decoder_num_blocks: paddle.Tensor = None

     _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
@@ -56,8 +51,6 @@ class BlockAttentionMetadata(AttentionMetadata):
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    encoder_block_shape_q: int = -1
-    decoder_block_shape_q: int = -1
     _fuse_kernel_compute_dtype: str = "bf16"

     # pd_disaggregation
@@ -79,6 +72,8 @@ class BlockAttentionBackend(AttentionBackend):
         kv_num_heads: int,
         num_heads: int,
         head_dim: int,
+        encoder_block_shape_q: int = -1,
+        decoder_block_shape_q: int = -1,
     ):
         """
         BlockAttentionBackend __init__
@@ -53,8 +53,6 @@ class FlashAttentionMetadata(AttentionMetadata):
     FlashAttentionMetadata
     """

-    max_len_kv: paddle.Tensor = None
-    set_max_lengths: int = -1
     rotary_embs: Optional[paddle.Tensor] = None
     block_tables: Optional[paddle.Tensor] = None
     encoder_batch_ids: paddle.Tensor = None
@@ -63,12 +61,6 @@ class FlashAttentionMetadata(AttentionMetadata):
     kv_batch_ids: paddle.Tensor = None
     kv_tile_ids_per_batch: paddle.Tensor = None
     kv_num_blocks: paddle.Tensor = None
-    decoder_batch_ids: paddle.Tensor = None
-    decoder_tile_ids_per_batch: paddle.Tensor = None
-    decoder_num_blocks: paddle.Tensor = None
-
-    encoder_block_shape_q: int = -1
-    decoder_block_shape_q: int = -1

     cu_seqlens_q: paddle.Tensor = None
     cu_seqlens_k: paddle.Tensor = None
@@ -100,6 +92,8 @@ class FlashAttentionBackend(AttentionBackend):
         kv_num_heads: int,
         num_heads: int,
         head_dim: int,
+        encoder_block_shape_q: int = -1,
+        decoder_block_shape_q: int = -1,
     ):
         """
         FlashAttentionBackend __init__
@@ -111,10 +105,13 @@ class FlashAttentionBackend(AttentionBackend):

         self.kv_num_heads = kv_num_heads
         self.num_heads = num_heads
+        self.group_size: int = self.num_heads // self.kv_num_heads
         self.head_dim = fd_config.model_config.head_dim
         self.attn_outputsize_tp = self.num_heads * self.head_dim
         self.block_size = fd_config.cache_config.block_size
         self.num_layers: int = fd_config.model_config.num_hidden_layers
+        self.encoder_block_shape_q: int = encoder_block_shape_q
+        self.decoder_block_shape_q: int = decoder_block_shape_q

         self.speculative_method = fd_config.speculative_config.method
         self.use_speculate = self.speculative_method is not None
@@ -176,8 +173,6 @@ class FlashAttentionBackend(AttentionBackend):

     def init_attention_metadata(self, forward_meta: ForwardMeta):
         metadata = FlashAttentionMetadata()
-        metadata.encoder_block_shape_q = 64
-        metadata.decoder_block_shape_q = 16
         metadata.cu_seqlens_q = forward_meta.cu_seqlens_q
         metadata.rotary_embs = forward_meta.rotary_embs
         metadata.block_tables = forward_meta.block_tables
@@ -188,18 +183,18 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.kv_batch_ids,
             metadata.kv_tile_ids_per_batch,
             metadata.kv_num_blocks,
-            metadata.decoder_batch_ids,
-            metadata.decoder_tile_ids_per_batch,
-            metadata.decoder_num_blocks,
             metadata.max_len_kv,
-            metadata.set_max_lengths,
         ) = get_block_shape_and_split_kv_block(
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
-            metadata.encoder_block_shape_q,
-            metadata.decoder_block_shape_q,
-            self.num_heads // self.kv_num_heads,
+            forward_meta.decoder_batch_ids,
+            forward_meta.decoder_tile_ids_per_batch,
+            forward_meta.decoder_num_blocks_cpu,
+            forward_meta.max_len_tensor_cpu,
+            self.encoder_block_shape_q,
+            self.decoder_block_shape_q,
+            self.group_size,
             self.block_size,
             self.speculate_max_draft_token_num + 1,
         )
@@ -233,8 +228,6 @@ class FlashAttentionBackend(AttentionBackend):
             self.rank, int(self.device_id), self.keep_pd_step_flag
         )
         self.attention_metadata = metadata
-        forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
-        forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False)

     def forward_mixed(
         self,
@@ -291,8 +284,8 @@ class FlashAttentionBackend(AttentionBackend):
             v,
             metadata.cu_seqlens_q,
             metadata.cu_seqlens_k,
-            max_seqlen_q=metadata.set_max_lengths[0],
-            max_seqlen_k=metadata.set_max_lengths[3],
+            max_seqlen_q=forward_meta.max_len_tensor_cpu[0],
+            max_seqlen_k=forward_meta.max_len_tensor_cpu[3],
             causal=self.causal,
             **self.flash_attn_kwargs,
         )[0].reshape([-1, self.attn_outputsize_tp])
@@ -64,17 +64,13 @@ class MLAAttentionMetadata(AttentionMetadata):
     MLAAttentionMetadata for Multi-Layer Attention
     """

-    max_len_kv: paddle.Tensor = None
-    set_max_lengths: int = -1
     encoder_batch_ids: paddle.Tensor = None
     encoder_tile_ids_per_batch: paddle.Tensor = None
     encoder_num_blocks: paddle.Tensor = None
     kv_batch_ids: paddle.Tensor = None
     kv_tile_ids_per_batch: paddle.Tensor = None
     kv_num_blocks: paddle.Tensor = None
-    decoder_batch_ids: paddle.Tensor = None
-    decoder_tile_ids_per_batch: paddle.Tensor = None
-    decoder_num_blocks: paddle.Tensor = None
+    max_len_kv: paddle.Tensor = None

     _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
@@ -82,8 +78,6 @@ class MLAAttentionMetadata(AttentionMetadata):
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    encoder_block_shape_q: int = -1
-    decoder_block_shape_q: int = -1
     _fuse_kernel_compute_dtype: str = "bf16"

     # pd_disaggregation
@@ -105,6 +99,8 @@ class MLAAttentionBackend(AttentionBackend):
         kv_num_heads: int,
         num_heads: int,
         head_dim: int,
+        encoder_block_shape_q: int = -1,
+        decoder_block_shape_q: int = -1,
     ) -> None:
         """
         MLAAttentionBackend __init__
@@ -128,8 +124,11 @@ class MLAAttentionBackend(AttentionBackend):

         self.kv_num_heads: int = kv_num_heads
         self.num_heads: int = num_heads
+        self.group_size: int = self.num_heads // self.kv_num_heads
         self.head_dim: int = fd_config.model_config.head_dim
         self.num_layers: int = fd_config.model_config.num_hidden_layers
+        self.encoder_block_shape_q: int = encoder_block_shape_q
+        self.decoder_block_shape_q: int = decoder_block_shape_q

         # For Multi Head Latent Attention
         self.kv_lora_rank: int = fd_config.model_config.kv_lora_rank
@@ -152,8 +151,6 @@ class MLAAttentionBackend(AttentionBackend):
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attention metadata hence all layers in the forward pass can reuse it."""
         metadata = MLAAttentionMetadata()
-        metadata.encoder_block_shape_q = 64
-        metadata.decoder_block_shape_q = 16
         metadata.max_partition_size = 32768
         metadata.encoder_max_partition_size = self.max_seq_len
         metadata._dtype = paddle.get_default_dtype()
@@ -176,27 +173,25 @@ class MLAAttentionBackend(AttentionBackend):
             metadata.kv_batch_ids,
             metadata.kv_tile_ids_per_batch,
             metadata.kv_num_blocks,
-            metadata.decoder_batch_ids,
-            metadata.decoder_tile_ids_per_batch,
-            metadata.decoder_num_blocks,
             metadata.max_len_kv,
-            metadata.set_max_lengths,
         ) = get_block_shape_and_split_kv_block(
             forward_meta.seq_lens_encoder,
             forward_meta.seq_lens_decoder,
             forward_meta.seq_lens_this_time,
-            metadata.encoder_block_shape_q,
-            metadata.decoder_block_shape_q,
-            self.num_heads // self.kv_num_heads,
+            forward_meta.decoder_batch_ids,
+            forward_meta.decoder_tile_ids_per_batch,
+            forward_meta.decoder_num_blocks_cpu,
+            forward_meta.max_len_tensor_cpu,
+            self.encoder_block_shape_q,
+            self.decoder_block_shape_q,
+            self.group_size,
             self.block_size,
             self.speculate_max_draft_token_num + 1,
         )

         # MLA
-        metadata.max_enc_len_this_time = metadata.set_max_lengths[1]
-        metadata.max_dec_len_this_time = metadata.set_max_lengths[2]
-        forward_meta.max_enc_len_this_time = metadata.set_max_lengths[1]
-        forward_meta.max_dec_len_this_time = metadata.set_max_lengths[2]
+        metadata.max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
+        metadata.max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]

         # pd_disaggregation
         metadata.kv_signal_data_list = [None] * self.num_layers
@@ -216,9 +211,6 @@ class MLAAttentionBackend(AttentionBackend):

         self.attention_metadata: AttentionMetadata = metadata

-        forward_meta.decoder_batch_ids.copy_(metadata.decoder_batch_ids, False)
-        forward_meta.decoder_tile_ids_per_batch.copy_(metadata.decoder_tile_ids_per_batch, False)
-
     def get_attntion_meta(self) -> AttentionMetadata:
         """get_attntion_meta"""
         return self.attention_metadata
@@ -354,8 +346,8 @@ class MLAAttentionBackend(AttentionBackend):
             metadata.kv_num_blocks,
             forward_meta.decoder_batch_ids,
             forward_meta.decoder_tile_ids_per_batch,
-            metadata.decoder_num_blocks,
-            metadata.decoder_num_blocks,  # PaddleNLP passes decoder_num_blocks_cpu here
+            forward_meta.decoder_num_blocks_cpu,
+            forward_meta.decoder_num_blocks_cpu,
             metadata.max_enc_len_this_time,
             metadata.max_dec_len_this_time,
             metadata.max_len_kv,
@@ -476,8 +468,8 @@ class MLAAttentionBackend(AttentionBackend):
             metadata.kv_num_blocks,
             forward_meta.decoder_batch_ids,
             forward_meta.decoder_tile_ids_per_batch,
-            metadata.decoder_num_blocks,
-            metadata.decoder_num_blocks,  # PaddleNLP passes decoder_num_blocks_cpu here
+            forward_meta.decoder_num_blocks_cpu,
+            forward_meta.decoder_num_blocks_cpu,
             metadata.max_enc_len_this_time,
             metadata.max_dec_len_this_time,
             metadata.max_len_kv,
@@ -28,6 +28,10 @@ def get_block_shape_and_split_kv_block(
     seq_lens_encoder: paddle.Tensor,
     seq_lens_decoder: paddle.Tensor,
     seq_lens_this_time: paddle.Tensor,
+    decoder_batch_ids: paddle.Tensor,
+    decoder_tile_ids_per_batch: paddle.Tensor,
+    decoder_num_blocks_x_cpu: paddle.Tensor,
+    max_len_tensor_cpu: paddle.Tensor,
     encoder_block_shape_q: int,
     decoder_block_shape_q: int,
     group_size: int,
@@ -45,15 +49,15 @@ def get_block_shape_and_split_kv_block(
             kv_batch_ids,
             kv_tile_ids_per_batch,
             kv_num_blocks,
-            decoder_batch_ids,
-            decoder_tile_ids_per_batch,
-            decoder_num_blocks,
-            max_len_kv,
-            set_max_lengths,
+            max_len_kv_cpu,
         ) = get_block_shape_and_split_kv_block_cuda(
             seq_lens_encoder,
             seq_lens_decoder,
             seq_lens_this_time,
+            decoder_batch_ids,
+            decoder_tile_ids_per_batch,
+            decoder_num_blocks_x_cpu,
+            max_len_tensor_cpu,
             encoder_block_shape_q,
             decoder_block_shape_q,
             group_size,
@@ -67,11 +71,7 @@ def get_block_shape_and_split_kv_block(
             kv_batch_ids,
             kv_tile_ids_per_batch,
             kv_num_blocks,
-            decoder_batch_ids,
-            decoder_tile_ids_per_batch,
-            decoder_num_blocks,
-            max_len_kv,
-            set_max_lengths,
+            max_len_kv_cpu,
         )
     else:
         raise NotImplementedError
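Note the narrower return contract of the Python wrapper: it no longer hands back decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks, or set_max_lengths. Those now live in the caller-owned in-place buffers, and per-step max lengths are read straight from max_len_tensor_cpu. A condensed sketch of the reads the backends rely on (taken from the backend hunks in this diff; index meanings follow the inferred layout listed earlier):

    # After the op has run, backends index the in-place CPU tensor directly:
    max_seqlen_q = forward_meta.max_len_tensor_cpu[0]          # max_len_this_time
    max_enc_len_this_time = forward_meta.max_len_tensor_cpu[1]
    max_dec_len_this_time = forward_meta.max_len_tensor_cpu[2]
    max_seqlen_k = forward_meta.max_len_tensor_cpu[3]          # max_enc_dec_len_this_time (inferred)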
@@ -44,26 +44,13 @@ class XPUAttentionMetadata(AttentionMetadata):
     XPUAttentionMetadata
     """

-    max_len_kv: paddle.Tensor = None
-    set_max_lengths: int = -1
-    encoder_batch_ids: paddle.Tensor = None
-    encoder_tile_ids_per_batch: paddle.Tensor = None
-    encoder_num_blocks: paddle.Tensor = None
-    kv_batch_ids: paddle.Tensor = None
-    kv_tile_ids_per_batch: paddle.Tensor = None
-    kv_num_blocks: paddle.Tensor = None
-    decoder_batch_ids: paddle.Tensor = None
-    decoder_tile_ids_per_batch: paddle.Tensor = None
-    decoder_num_blocks: paddle.Tensor = None
-
     _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
     max_partition_size: int = 32768
     block_tables: Optional[paddle.Tensor] = None
     rotary_embs: Optional[paddle.Tensor] = None
     attn_mask: Optional[paddle.Tensor] = None
-    encoder_block_shape_q: int = -1
-    decoder_block_shape_q: int = -1
     _fuse_kernel_compute_dtype: str = "bf16"

     # pd_disaggregation
@@ -91,7 +78,6 @@ class XPUAttentionBackend(AttentionBackend):
         """
         super().__init__()
         self.attention_metadata: XPUAttentionMetadata = None
-        # TODO(gongshaotian): Use fd_config parameters in the correct location
         self.block_size: int = fd_config.cache_config.block_size
         self.max_seq_len: int = fd_config.parallel_config.max_model_len
         self.rope_theta: float = (
@@ -99,9 +85,6 @@ class XPUAttentionBackend(AttentionBackend):
         )
         self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
         self.causal: bool = getattr(fd_config.model_config, "causal", True)
-        # self.speculate_method = fd_config.parallel_config.speculate_method
-        # self.use_speculate = self.speculate_method is not None
-        # self.speculate_max_draft_token_num = fd_config.parallel_config.speculate_max_draft_tokens
         self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
         self.rank: int = fd_config.parallel_config.tensor_parallel_rank

@@ -117,8 +100,6 @@ class XPUAttentionBackend(AttentionBackend):
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attntion metadata hence all layers in the forward pass can reuse it."""
         metadata = XPUAttentionMetadata()
-        metadata.encoder_block_shape_q = 64
-        metadata.decoder_block_shape_q = 16
         metadata.max_partition_size = 32768
         metadata.encoder_max_partition_size = 32768
         metadata._dtype = paddle.get_default_dtype()
@@ -184,13 +184,26 @@ class MTPProposer(Proposer):
         """
         assert len(self.attn_backends) == 0

-        # TODO(gongshaotian): Get rank from config
         num_heads = self.model_config.num_attention_heads // self.parallel_config.tensor_parallel_size
-        self.model_config.kv_num_heads = (
-            int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size
+        self.model_config.kv_num_heads = max(
+            1,
+            int(self.model_config.num_key_value_heads) // self.parallel_config.tensor_parallel_size,
         )
         head_dim = self.model_config.head_dim

+        # Initialize AttentionBackend buffers
+        encoder_block_shape_q = 64
+        decoder_block_shape_q = 16
+
+        self.model_inputs["decoder_batch_ids"] = paddle.zeros_like(self.main_model_inputs["decoder_batch_ids"])
+        self.model_inputs["decoder_tile_ids_per_batch"] = paddle.zeros_like(
+            self.main_model_inputs["decoder_tile_ids_per_batch"]
+        )
+        self.model_inputs["decoder_num_blocks_cpu"] = paddle.zeros_like(
+            self.main_model_inputs["decoder_num_blocks_cpu"]
+        ).pin_memory()
+        self.model_inputs["max_len_tensor_cpu"] = paddle.zeros_like(self.main_model_inputs["max_len_tensor_cpu"]).cpu()
+
         # Get the attention backend
         attn_cls = get_attention_backend()
         attn_backend = attn_cls(
@@ -198,6 +211,8 @@ class MTPProposer(Proposer):
             kv_num_heads=self.model_config.kv_num_heads,
             num_heads=num_heads,
             head_dim=head_dim,
+            encoder_block_shape_q=encoder_block_shape_q,
+            decoder_block_shape_q=decoder_block_shape_q,
         )
         if attn_backend is None:
             raise NotImplementedError(
@@ -293,6 +308,12 @@ class MTPProposer(Proposer):
         self.model_inputs["base_model_draft_tokens"] = self.main_model_inputs["draft_tokens"]
         self.model_inputs["substep"] = 0

+        # Declare AttentionBackend buffers
+        self.model_inputs["decoder_batch_ids"] = None
+        self.model_inputs["decoder_tile_ids_per_batch"] = None
+        self.model_inputs["decoder_num_blocks_cpu"] = None  # Pinning Memory
+        self.model_inputs["max_len_tensor_cpu"] = None  # CPU
+
         # Input tokens
         self.model_inputs["draft_tokens"] = paddle.full(shape=[self.max_num_seqs, 2], fill_value=-1, dtype="int64")

@@ -405,6 +426,8 @@ class MTPProposer(Proposer):
             attn_backend=self.attn_backends[0],
             decoder_batch_ids=self.model_inputs["decoder_batch_ids"],
             decoder_tile_ids_per_batch=self.model_inputs["decoder_tile_ids_per_batch"],
+            decoder_num_blocks_cpu=self.model_inputs["decoder_num_blocks_cpu"],
+            max_len_tensor_cpu=self.model_inputs["max_len_tensor_cpu"],
             seq_lens_encoder=self.model_inputs["seq_lens_encoder"],
             seq_lens_decoder=self.model_inputs["seq_lens_decoder"],
             seq_lens_this_time=self.model_inputs["seq_lens_this_time"],
@@ -417,9 +417,11 @@ class GCUModelRunner(ModelRunnerBase):
         self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
-        # AttentionBackend buffers
-        self.share_inputs["decoder_batch_ids"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
-        self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
+        # Declare AttentionBackend buffers
+        self.share_inputs["decoder_batch_ids"] = None
+        self.share_inputs["decoder_tile_ids_per_batch"] = None
+        self.share_inputs["decoder_num_blocks_cpu"] = None  # Pinning Memory
+        self.share_inputs["max_len_tensor_cpu"] = None  # CPU

         # Initialize rotary position embedding
         tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
@@ -579,6 +581,8 @@ class GCUModelRunner(ModelRunnerBase):
             attn_backend=self.attn_backends[0],
             decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
             decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
+            decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"],
+            max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"],
             seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
             seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
             seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
@@ -655,6 +659,18 @@ class GCUModelRunner(ModelRunnerBase):
         )
         head_dim = self.model_config.head_dim

+        # Initialize AttentionBackend buffers
+        encoder_block_shape_q = 64
+        decoder_block_shape_q = 16
+        decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
+        decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
+            (decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q
+        )
+        self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
+        self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
+        self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
+        self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()
+
         # Get the attention backend
         attn_cls = get_attention_backend()
         attn_backend = attn_cls(
@@ -662,6 +678,8 @@ class GCUModelRunner(ModelRunnerBase):
             kv_num_heads=self.model_config.kv_num_heads,
             num_heads=num_heads,
             head_dim=head_dim,
+            encoder_block_shape_q=encoder_block_shape_q,
+            decoder_block_shape_q=decoder_block_shape_q,
         )
         if attn_backend is None:
             raise NotImplementedError(
@@ -1179,9 +1197,5 @@ class GCUModelRunner(ModelRunnerBase):
         Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
         In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch.
         """
-        # TODO(gongshaotian): Use more efficient implementation
-        if self.forward_meta.step_use_cudagraph:
-            num_empty_batch = (self.forward_meta.seq_lens_this_time == 0).sum()
-            for i in range(1, num_empty_batch + 1):
-                self.forward_meta.decoder_batch_ids[-i] = 0
-                self.forward_meta.decoder_tile_ids_per_batch[-i] = 0
+        # In init_attention_metadata, the decode buffer has already been cleared
+        return
@@ -610,9 +610,7 @@ class GPUModelRunner(ModelRunnerBase):
         self.share_inputs["step_seq_lens_decoder"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["prompt_lens"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
         self.share_inputs["step_idx"] = paddle.full([max_num_seqs, 1], 0, dtype="int64")
-        self.share_inputs["not_need_stop"] = paddle.full(
-            [1], False, dtype="bool"
-        ).cpu()  # TODO(gongshaotian): move to pinnd memory
+        self.share_inputs["not_need_stop"] = paddle.full([1], False, dtype="bool").cpu()
         self.share_inputs["stop_flags"] = paddle.full([max_num_seqs, 1], True, dtype="bool")
         self.share_inputs["stop_nums"] = paddle.full([1], max_num_seqs, dtype="int64")

@@ -643,9 +641,12 @@ class GPUModelRunner(ModelRunnerBase):
         self.share_inputs["batch_id_per_token"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
-        # AttentionBackend buffers
-        self.share_inputs["decoder_batch_ids"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
-        self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
+        # Declare AttentionBackend buffers
+        self.share_inputs["decoder_batch_ids"] = None
+        self.share_inputs["decoder_tile_ids_per_batch"] = None
+        self.share_inputs["decoder_num_blocks_cpu"] = None  # Pinning Memory
+        self.share_inputs["max_len_tensor_cpu"] = None  # CPU

         # Initialize rotary position embedding
         tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))
@@ -845,6 +846,8 @@ class GPUModelRunner(ModelRunnerBase):
             attn_backend=self.attn_backends[0],
             decoder_batch_ids=self.share_inputs["decoder_batch_ids"],
             decoder_tile_ids_per_batch=self.share_inputs["decoder_tile_ids_per_batch"],
+            decoder_num_blocks_cpu=self.share_inputs["decoder_num_blocks_cpu"],
+            max_len_tensor_cpu=self.share_inputs["max_len_tensor_cpu"],
             seq_lens_encoder=self.share_inputs["seq_lens_encoder"],
             seq_lens_decoder=self.share_inputs["seq_lens_decoder"],
             seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
@@ -856,7 +859,6 @@ class GPUModelRunner(ModelRunnerBase):
         )

         # Update Batch type for cuda graph
-        # TODO(gongshaotian): Use seq_lens_encoder to set is_decode_batch
         only_decode_batch = True
         prefill_exists = None
         # mix ep in single node
@@ -946,6 +948,18 @@ class GPUModelRunner(ModelRunnerBase):
         )
         head_dim = self.model_config.head_dim

+        # Initialize AttentionBackend buffers
+        encoder_block_shape_q = 64
+        decoder_block_shape_q = 16
+        decoder_step_token_num = self.speculative_config.num_speculative_tokens + 1
+        decode_max_tile_size = self.parallel_config.max_num_seqs * np.ceil(
+            (decoder_step_token_num * np.ceil(num_heads / self.model_config.kv_num_heads)) / decoder_block_shape_q
+        )
+        self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
+        self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
+        self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
+        self.share_inputs["max_len_tensor_cpu"] = paddle.full([8], 0, dtype="int32").cpu()
+
         # Get the attention backend
         attn_cls = get_attention_backend()
         attn_backend = attn_cls(
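As a worked example of the decode_max_tile_size allocation above (illustrative numbers, not taken from the diff): with max_num_seqs = 128, num_heads = 32, kv_num_heads = 4, no speculative tokens (so decoder_step_token_num = 1) and decoder_block_shape_q = 16:

    import numpy as np

    max_num_seqs, num_heads, kv_num_heads = 128, 32, 4          # assumed example values
    decoder_step_token_num, decoder_block_shape_q = 1, 16

    decode_max_tile_size = max_num_seqs * np.ceil(
        (decoder_step_token_num * np.ceil(num_heads / kv_num_heads)) / decoder_block_shape_q
    )
    print(int(decode_max_tile_size))  # 128 -> decoder_batch_ids / decoder_tile_ids_per_batch each hold 128 int32s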
||||||
@@ -953,6 +967,8 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
kv_num_heads=self.model_config.kv_num_heads,
|
kv_num_heads=self.model_config.kv_num_heads,
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
head_dim=head_dim,
|
head_dim=head_dim,
|
||||||
|
encoder_block_shape_q=encoder_block_shape_q,
|
||||||
|
decoder_block_shape_q=decoder_block_shape_q,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.attn_backends.append(attn_backend)
|
self.attn_backends.append(attn_backend)
|
||||||
@@ -1527,12 +1543,8 @@ class GPUModelRunner(ModelRunnerBase):
|
|||||||
Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
|
Clean buffers used for the CUDA graph when replaying the CUDA graph with the padded batch.
|
||||||
In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch.
|
In FastDeploy, almost all input tensors have a buffer. So, just keep the buffer clean when replaying the CUDA graph with the padded batch.
|
||||||
"""
|
"""
|
||||||
# TODO(gongshaotian): Use more efficient implementation
|
# In init_attention_metadata, the decode buffer has already been cleared
|
||||||
if self.forward_meta.step_use_cudagraph:
|
return
|
||||||
num_empty_batch = (self.forward_meta.seq_lens_this_time == 0).sum()
|
|
||||||
for i in range(1, num_empty_batch + 1):
|
|
||||||
self.forward_meta.decoder_batch_ids[-i] = 0
|
|
||||||
self.forward_meta.decoder_tile_ids_per_batch[-i] = 0
|
|
||||||
|
|
||||||
def _init_image_preprocess(self) -> None:
|
def _init_image_preprocess(self) -> None:
|
||||||
processor = DataProcessor(
|
processor = DataProcessor(
|
||||||
|
@@ -383,8 +383,8 @@ class IluvatarModelRunner(ModelRunnerBase):
         self.share_inputs["cu_seqlens_q"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         self.share_inputs["cu_seqlens_k"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
         # AttentionBackend buffers
-        self.share_inputs["decoder_batch_ids"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
-        self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([max_num_seqs, 1], 0, dtype="int32")
+        self.share_inputs["decoder_batch_ids"] = None
+        self.share_inputs["decoder_tile_ids_per_batch"] = None

         # Initialize rotary position embedding
         tmp_position_ids = paddle.arange(self.parallel_config.max_model_len).reshape((1, -1))