def allocate_launch_related_buffer(
    max_batch_size,
    max_model_len,
    encoder_block_shape_q,
    decoder_block_shape_q,
    decoder_step_token_num,
    num_heads,
    kv_num_heads,
    block_size,
):
    """
    Pre-allocate the launch-related buffers consumed by the attention backend.

    Three upper-bound tile counts (decoder, encoder, kv) are derived from the
    arguments, and one int32 tensor is created per buffer name.

    Args:
        max_batch_size: Multiplier applied to every tile count.
        max_model_len: Sequence length used for the encoder and kv tile counts.
        encoder_block_shape_q: Divisor for the encoder tile count.
        decoder_block_shape_q: Divisor for the decoder tile count.
        decoder_step_token_num: Tokens per decode step, used for the decoder
            tile count.
        num_heads: Number of attention heads.
        kv_num_heads: Number of key/value heads; ``ceil(num_heads / kv_num_heads)``
            is used as the group size.
        block_size: Divisor for the kv tile count.

    Returns:
        dict: Buffer name -> int32 ``paddle.Tensor``. ``decoder_num_blocks_cpu``
        is moved with ``.pin_memory()``, the other ``*_cpu`` entries with
        ``.cpu()``; the remaining tensors stay wherever ``paddle.full`` places
        them.
    """

    def _int32_buffer(length, fill_value=0):
        # Every buffer is a 1-D int32 tensor; only length and fill value vary.
        return paddle.full([int(length)], fill_value, dtype="int32")

    # Initialize AttentionBackend buffers
    heads_per_kv_head = np.ceil(num_heads / kv_num_heads)

    # NOTE: (changwenbin) When using auto_chunk,
    # decode_max_tile_size must take into account the maximum case, where *1024 can cover 128K.
    decoder_tiles_per_seq = np.ceil((decoder_step_token_num * heads_per_kv_head) / decoder_block_shape_q)
    decoder_tile_capacity = 1024 * max_batch_size * decoder_tiles_per_seq

    encoder_tiles_per_seq = np.ceil((max_model_len * heads_per_kv_head) / encoder_block_shape_q)
    encoder_tile_capacity = max_batch_size * encoder_tiles_per_seq

    kv_tiles_per_seq = np.ceil(max_model_len / block_size)
    kv_tile_capacity = max_batch_size * kv_tiles_per_seq

    return {
        "decoder_batch_ids": _int32_buffer(decoder_tile_capacity),
        "decoder_tile_ids_per_batch": _int32_buffer(decoder_tile_capacity),
        "decoder_num_blocks_cpu": _int32_buffer(1).pin_memory(),
        # NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor,
        # adapted to cudagraph.
        "decoder_num_blocks_device": _int32_buffer(1),
        "decoder_chunk_size_device": _int32_buffer(1, 64),
        "max_len_tensor_cpu": _int32_buffer(9).cpu(),
        "encoder_batch_ids": _int32_buffer(encoder_tile_capacity),
        "encoder_tile_ids_per_batch": _int32_buffer(encoder_tile_capacity),
        "encoder_num_blocks_x_cpu": _int32_buffer(1).cpu(),
        "kv_batch_ids": _int32_buffer(kv_tile_capacity),
        "kv_tile_ids_per_batch": _int32_buffer(kv_tile_capacity),
        "kv_num_blocks_x_cpu": _int32_buffer(1).cpu(),
    }
- decode_max_tile_size = ( - 1024 - * self.scheduler_config.max_num_seqs - * np.ceil((decoder_step_token_num * group_size) / decoder_block_shape_q) + res_buffer = allocate_launch_related_buffer( + max_batch_size=self.scheduler_config.max_num_seqs, + max_model_len=self.model_config.max_model_len, + encoder_block_shape_q=encoder_block_shape_q, + decoder_block_shape_q=decoder_block_shape_q, + decoder_step_token_num=self.speculative_config.num_speculative_tokens + 1, + num_heads=num_heads, + kv_num_heads=self.model_config.kv_num_heads, + block_size=self.fd_config.cache_config.block_size, ) - encode_max_tile_size = self.scheduler_config.max_num_seqs * np.ceil( - (self.model_config.max_model_len * group_size) / encoder_block_shape_q - ) - kv_max_tile_size = self.scheduler_config.max_num_seqs * np.ceil( - self.model_config.max_model_len / self.fd_config.cache_config.block_size - ) - self.share_inputs["decoder_batch_ids"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") - self.share_inputs["decoder_tile_ids_per_batch"] = paddle.full([int(decode_max_tile_size)], 0, dtype="int32") - self.share_inputs["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory() - # NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor, - # adapted to cudagraph. 
- self.share_inputs["decoder_num_blocks_device"] = paddle.full([1], 0, dtype="int32") - self.share_inputs["decoder_chunk_size_device"] = paddle.full([1], 64, dtype="int32") - self.share_inputs["max_len_tensor_cpu"] = paddle.full([9], 0, dtype="int32").cpu() - - self.share_inputs["encoder_batch_ids"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") - self.share_inputs["encoder_tile_ids_per_batch"] = paddle.full([int(encode_max_tile_size)], 0, dtype="int32") - self.share_inputs["encoder_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() - - self.share_inputs["kv_batch_ids"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") - self.share_inputs["kv_tile_ids_per_batch"] = paddle.full([int(kv_max_tile_size)], 0, dtype="int32") - self.share_inputs["kv_num_blocks_x_cpu"] = paddle.full([1], 0, dtype="int32").cpu() + self.share_inputs.update(res_buffer) # Get the attention backend attn_cls = get_attention_backend() diff --git a/tests/layers/test_attention_layer.py b/tests/layers/test_attention_layer.py index 147dac6f7..11911576a 100644 --- a/tests/layers/test_attention_layer.py +++ b/tests/layers/test_attention_layer.py @@ -22,7 +22,6 @@ import time import types import unittest -import numpy as np import paddle from paddle import nn @@ -44,6 +43,9 @@ from fastdeploy.model_executor.layers.attention import ( AttentionBackend, get_attention_backend, ) +from fastdeploy.model_executor.layers.attention.append_attn_backend import ( + allocate_launch_related_buffer, +) from fastdeploy.model_executor.layers.quantization import parse_quant_config from fastdeploy.model_executor.layers.rotary_embedding import get_rope from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_Attention @@ -193,38 +195,6 @@ class TestAttentionPerformance(unittest.TestCase): } return state_dict - def create_attn_backend_buffers(self, m_config: ModelConfig, batch_size: int, block_size: int) -> dict: - """ - Pre-allocates metadata buffers required by the Attention 
backend. - """ - encoder_block_shape_q = 64 - decoder_block_shape_q = 16 - decoder_step_token_num = 1 - num_heads = m_config.num_attention_heads - kv_num_heads = m_config.num_key_value_heads - group_size = np.ceil(num_heads / kv_num_heads) - - decode_max_tile_size = ( - 1024 * batch_size * np.ceil((decoder_step_token_num * group_size) / decoder_block_shape_q) - ) - encode_max_tile_size = batch_size * np.ceil((m_config.max_model_len * group_size) / encoder_block_shape_q) - kv_max_tile_size = batch_size * np.ceil(m_config.max_model_len / block_size) - - return { - "decoder_batch_ids": paddle.full([int(decode_max_tile_size)], 0, dtype="int32"), - "decoder_tile_ids_per_batch": paddle.full([int(decode_max_tile_size)], 0, dtype="int32"), - "decoder_num_blocks_cpu": paddle.full([1], 0, dtype="int32").pin_memory(), - "decoder_num_blocks_device": paddle.full([1], 0, dtype="int32"), - "decoder_chunk_size_device": paddle.full([1], 64, dtype="int32"), - "max_len_tensor_cpu": paddle.full([8], 0, dtype="int32").cpu(), - "encoder_batch_ids": paddle.full([int(encode_max_tile_size)], 0, dtype="int32"), - "encoder_tile_ids_per_batch": paddle.full([int(encode_max_tile_size)], 0, dtype="int32"), - "encoder_num_blocks_x_cpu": paddle.full([1], 0, dtype="int32").cpu(), - "kv_batch_ids": paddle.full([int(kv_max_tile_size)], 0, dtype="int32"), - "kv_tile_ids_per_batch": paddle.full([int(kv_max_tile_size)], 0, dtype="int32"), - "kv_num_blocks_x_cpu": paddle.full([1], 0, dtype="int32").cpu(), - } - def create_forward_meta( self, batch_size: int, @@ -252,8 +222,15 @@ class TestAttentionPerformance(unittest.TestCase): else: raise ValueError(f"Unsupported ForwardMode: {mode}") - attn_backend_buffers = self.create_attn_backend_buffers( - fd_config.model_config, batch_size, fd_config.cache_config.block_size + attn_backend_buffers = allocate_launch_related_buffer( + max_batch_size=batch_size, + max_model_len=fd_config.model_config.max_model_len, + encoder_block_shape_q=64, + 
decoder_block_shape_q=16, + decoder_step_token_num=1, + num_heads=fd_config.model_config.num_attention_heads, + kv_num_heads=fd_config.model_config.num_key_value_heads, + block_size=fd_config.cache_config.block_size, ) if existing_caches is None: