Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-12-24 13:28:13 +08:00.
Commit: [ATTENTION] make buffer alloc as a function (#4945).
This commit is contained in:
@@ -34,6 +34,8 @@ from fastdeploy.model_executor.layers.attention.ops import (
|
||||
if TYPE_CHECKING:
|
||||
from fastdeploy.model_executor.forward_meta import ForwardMeta
|
||||
|
||||
import numpy as np
|
||||
|
||||
from fastdeploy.config import FDConfig
|
||||
from fastdeploy.model_executor.layers.attention.attention import Attention
|
||||
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
@@ -62,6 +64,47 @@ class AppendAttentionMetadata(AttentionMetadata):
|
||||
kv_signal_data_list: List[Optional[paddle.Tensor]] = field(default_factory=list)
|
||||
|
||||
|
||||
def allocate_launch_related_buffer(
    max_batch_size,
    max_model_len,
    encoder_block_shape_q,
    decoder_block_shape_q,
    decoder_step_token_num,
    num_heads,
    kv_num_heads,
    block_size,
):
    """Pre-allocate the kernel-launch scheduling buffers for the append-attention backend.

    All sizing inputs are positive integers describing the scheduling limits of
    the attention backend.

    Args:
        max_batch_size: Maximum number of sequences in a batch.
        max_model_len: Maximum sequence length supported by the model.
        encoder_block_shape_q: Query-tile size used by the encoder (prefill) kernel.
        decoder_block_shape_q: Query-tile size used by the decoder kernel.
        decoder_step_token_num: Tokens handled per decode step.
        num_heads: Number of query attention heads.
        kv_num_heads: Number of key/value heads.
        block_size: Number of tokens per KV-cache block.

    Returns:
        dict[str, paddle.Tensor]: Named int32 buffers. Tile-id buffers are on
        the default device; ``*_cpu`` buffers are on CPU (``decoder_num_blocks_cpu``
        in pinned memory).
    """

    def _ceil_div(a, b):
        # Exact integer ceiling division. Avoids the float rounding that
        # np.ceil(a / b) could introduce for very large operands.
        return -(-a // b)

    # Query heads served per KV head.
    group_size = _ceil_div(num_heads, kv_num_heads)

    # NOTE: (changwenbin) When using auto_chunk, decode_max_tile_size must take
    # into account the maximum case, where *1024 can cover 128K.
    decode_max_tile_size = 1024 * max_batch_size * _ceil_div(decoder_step_token_num * group_size, decoder_block_shape_q)
    encode_max_tile_size = max_batch_size * _ceil_div(max_model_len * group_size, encoder_block_shape_q)
    kv_max_tile_size = max_batch_size * _ceil_div(max_model_len, block_size)

    res = {}
    # Decoder-side scheduling buffers.
    res["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
    res["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32")
    res["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
    # NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place
    # of GPU tensor, adapted to cudagraph.
    res["decoder_num_blocks_device"] = paddle.full([1], 0, dtype="int32")
    res["decoder_chunk_size_device"] = paddle.full([1], 64, dtype="int32")
    res["max_len_tensor_cpu"] = paddle.full([9], 0, dtype="int32").cpu()

    # Encoder (prefill) scheduling buffers.
    res["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
    res["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32")
    res["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()

    # KV-cache tiling buffers.
    res["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
    res["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32")
    res["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu()

    return res
|
||||
|
||||
|
||||
class AppendAttentionBackend(AttentionBackend):
|
||||
"""
|
||||
AppendAttentionBackend backend implementation.
|
||||
|
||||
Reference in New Issue
Block a user