[Metax] optimize mla attention (#5258)

This commit is contained in:
xiaozude
2025-12-09 11:18:19 +08:00
committed by GitHub
parent 5d9b5e4a5b
commit c06a6234b9
6 changed files with 1026 additions and 377 deletions

View File

@@ -162,12 +162,11 @@ function copy_ops(){
is_maca=`$python -c "import paddle; print(paddle.device.is_compiled_with_custom_device('metax_gpu'))"`
if [ "$is_maca" = "True" ]; then
DEVICE_TYPE="metax_gpu"
mkdir -p ../fastdeploy/model_executor/ops/base
cp -r ${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
cp -r ${TMP_PACKAGE_DIR}/* ../fastdeploy/model_executor/ops/gpu
echo -e "MACA ops have been copy to fastdeploy"
return
fi
is_intel_hpu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('intel_hpu'))"`
if [ "$is_intel_hpu" = "True" ]; then
DEVICE_TYPE="intel-hpu"

View File

@@ -43,6 +43,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionMetadata,
)
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
from fastdeploy.platforms import current_platform
@dataclass
@@ -87,7 +88,10 @@ def allocate_launch_related_buffer(
res = {}
res["decoder_batch_ids"] = paddle.full([decode_max_tile_size], 0, dtype="int32")
res["decoder_tile_ids_per_batch"] = paddle.full([decode_max_tile_size], 0, dtype="int32")
res["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
if current_platform.is_maca():
res["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
else:
res["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
# NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor,
# adapted to cudagraph.
res["decoder_num_blocks_device"] = paddle.full([1], 0, dtype="int32")

View File

@@ -206,6 +206,9 @@ class MetaxMLAAttentionBackend(AttentionBackend):
self.seq_lens = seq_lens_decoder + seq_lens_this_time
self.block_tables = forward_meta.block_tables[non_zero_index]
self.tile_scheduler_metadata = None
self.num_splits = None
def get_attntion_meta(self) -> AttentionMetadata:
    """Return the attention metadata object stored on this backend."""
    metadata = self.attention_metadata
    return metadata
@@ -250,13 +253,13 @@ class MetaxMLAAttentionBackend(AttentionBackend):
]
)
query = query.reshape([-1, seq_len_q, num_heads_q, head_dim_qk])
query = query.reshape_([-1, seq_len_q, num_heads_q, head_dim_qk])
tile_scheduler_metadata, num_splits = get_mla_metadata(
self.seq_lens, seq_len_q * num_heads_q // num_heads_kv, num_heads_kv
)
assert tile_scheduler_metadata.shape[0] != 0
if self.tile_scheduler_metadata is None or self.num_splits is None:
self.tile_scheduler_metadata, self.num_splits = get_mla_metadata(
self.seq_lens, seq_len_q * num_heads_q // num_heads_kv, num_heads_kv
)
assert self.tile_scheduler_metadata.shape[0] != 0
out = flash_mla_with_kvcache(
query,
@@ -264,8 +267,8 @@ class MetaxMLAAttentionBackend(AttentionBackend):
self.block_tables,
self.seq_lens,
head_dim_v,
tile_scheduler_metadata,
num_splits,
self.tile_scheduler_metadata,
self.num_splits,
softmax_scale=self.attn_softmax_scale,
causal=self.causal,
)[0]
@@ -273,7 +276,7 @@ class MetaxMLAAttentionBackend(AttentionBackend):
if seq_len_q != self.seq_lens_this_time_min:
out = paddle.concat([paddle.split(x, [n, seq_len_q - n])[0] for x, n in zip(out, self.seq_lens_this_time)])
else:
out = out.reshape([-1, num_heads_q, head_dim_v])
out = out.reshape_([-1, num_heads_q, head_dim_v])
return out

View File

@@ -411,14 +411,14 @@ class DeepseekV3MLAAttention(nn.Layer):
forward_meta=forward_meta,
)
fmha_out_decode = fmha_out_decode.reshape([-1, self.num_attention_heads_tp, self.kv_lora_rank]).transpose(
fmha_out_decode = fmha_out_decode.reshape_([-1, self.num_attention_heads_tp, self.kv_lora_rank]).transpose(
[1, 0, 2]
)
fmha_out_decode = (
self.kv_b_proj_bmm(fmha_out_decode, proj_type="v")
.transpose([1, 0, 2])
.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim])
)
if fmha_out is None:

File diff suppressed because it is too large Load Diff

View File

@@ -120,23 +120,28 @@ class MetaxWorker(WorkerBase):
before_run_meminfo_used = info.vramUse * 1024
before_run_meminfo_free = before_run_meminfo_total - before_run_meminfo_used
logger.info("Before running the profile, the memory usage info of Metax GPU is as follows:")
logger.info(f"Device Index: {device_id}")
logger.info(f"Device Total memory: {before_run_meminfo_total / Gb}")
logger.info(f"Device used memory: {before_run_meminfo_used / Gb}")
logger.info(f"Device free memory: {before_run_meminfo_free / Gb}")
logger.info(f"Paddle reserved memory: {paddle_reserved_mem_before_run / Gb}")
logger.info(f"Paddle allocated memory: {paddle_allocated_mem_before_run / Gb}")
# Log the pre-profile memory snapshot as a single multi-line record.
# The literals below are adjacent with no commas between them, so Python
# joins them into one str. Comma-separating them would build a tuple, and
# the logger would emit that tuple's repr (with escaped "\n") instead of
# readable lines.
# Each memory figure is divided by Gb (defined outside this view) before
# it is printed.
logger.info(
    "Before running the profile, the memory usage info is as follows:"
    f"\nDevice Index: {device_id}"
    f"\nDevice Total memory: {before_run_meminfo_total / Gb}"
    f"\nDevice used memory: {before_run_meminfo_used / Gb}"
    f"\nDevice free memory: {before_run_meminfo_free / Gb}"
    f"\nPaddle reserved memory: {paddle_reserved_mem_before_run / Gb}"
    f"\nPaddle allocated memory: {paddle_allocated_mem_before_run / Gb}"
)
# 2. Profile run
self.model_runner.profile_run()
set_random_seed(self.fd_config.model_config.seed)
# 3. Statistical memory information
paddle_reserved_mem_after_run = paddle.device.max_memory_reserved(local_rank)
paddle_allocated_mem_after_run = paddle.device.max_memory_allocated(local_rank)
model_block_memory_used = self.cal_theortical_kvcache()
paddle_peak_increase = paddle_reserved_mem_after_run - paddle_allocated_mem_before_run
paddle_peak_increase = paddle_allocated_mem_after_run - paddle_allocated_mem_before_run
paddle.device.empty_cache()
@@ -154,15 +159,19 @@ class MetaxWorker(WorkerBase):
end_time = time.perf_counter()
logger.info("After running the profile, the memory usage info of Metax GPU is as follows:")
logger.info(f"Device Index: {device_id}")
logger.info(f"Device Total memory: {after_run_meminfo_total / Gb}")
logger.info(f"Device used memory: {after_run_meminfo_used / Gb}")
logger.info(f"Device free memory: {after_run_meminfo_free / Gb}")
logger.info(f"Paddle reserved memory: {paddle_reserved_mem_after_run / Gb}")
logger.info(f"Paddle allocated memory: {paddle_allocated_mem_after_run / Gb}")
logger.info(f"Paddle available_kv_cache_memory: {available_kv_cache_memory / Gb}")
logger.info(f"Profile time: {end_time - start_time}")
# Log the post-profile memory snapshot as a single multi-line record.
# The literals below are adjacent with no commas between them, so Python
# joins them into one str. Comma-separating them would build a tuple, and
# the logger would emit that tuple's repr (with escaped "\n") instead of
# readable lines.
# Each memory figure is divided by Gb (defined outside this view) before
# it is printed; the last line reports end_time - start_time.
logger.info(
    "After running the profile, the memory usage info is as follows:"
    f"\nDevice Index: {device_id}"
    f"\nDevice Total memory: {after_run_meminfo_total / Gb}"
    f"\nDevice used memory: {after_run_meminfo_used / Gb}"
    f"\nDevice free memory: {after_run_meminfo_free / Gb}"
    f"\nPaddle reserved memory: {paddle_reserved_mem_after_run / Gb}"
    f"\nPaddle allocated memory: {paddle_allocated_mem_after_run / Gb}"
    f"\nAvailable KV Cache memory: {available_kv_cache_memory / Gb}"
    f"\nProfile time: {end_time - start_time}"
)
return available_kv_cache_memory