mirror of
https://github.com/PaddlePaddle/FastDeploy.git
synced 2025-12-24 13:28:13 +08:00
[Metax] optimize mla attention (#5258)
This commit is contained in:
3
build.sh
3
build.sh
@@ -162,12 +162,11 @@ function copy_ops(){
|
||||
is_maca=`$python -c "import paddle; print(paddle.device.is_compiled_with_custom_device('metax_gpu'))"`
|
||||
if [ "$is_maca" = "True" ]; then
|
||||
DEVICE_TYPE="metax_gpu"
|
||||
mkdir -p ../fastdeploy/model_executor/ops/base
|
||||
cp -r ${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
|
||||
cp -r ${TMP_PACKAGE_DIR}/* ../fastdeploy/model_executor/ops/gpu
|
||||
echo -e "MACA ops have been copy to fastdeploy"
|
||||
return
|
||||
fi
|
||||
|
||||
is_intel_hpu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('intel_hpu'))"`
|
||||
if [ "$is_intel_hpu" = "True" ]; then
|
||||
DEVICE_TYPE="intel-hpu"
|
||||
|
||||
@@ -43,6 +43,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
|
||||
AttentionMetadata,
|
||||
)
|
||||
from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
|
||||
from fastdeploy.platforms import current_platform
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -87,7 +88,10 @@ def allocate_launch_related_buffer(
|
||||
res = {}
|
||||
res["decoder_batch_ids"] = paddle.full([decode_max_tile_size], 0, dtype="int32")
|
||||
res["decoder_tile_ids_per_batch"] = paddle.full([decode_max_tile_size], 0, dtype="int32")
|
||||
res["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
|
||||
if current_platform.is_maca():
|
||||
res["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
|
||||
else:
|
||||
res["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
|
||||
# NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor,
|
||||
# adapted to cudagraph.
|
||||
res["decoder_num_blocks_device"] = paddle.full([1], 0, dtype="int32")
|
||||
|
||||
@@ -206,6 +206,9 @@ class MetaxMLAAttentionBackend(AttentionBackend):
|
||||
self.seq_lens = seq_lens_decoder + seq_lens_this_time
|
||||
self.block_tables = forward_meta.block_tables[non_zero_index]
|
||||
|
||||
self.tile_scheduler_metadata = None
|
||||
self.num_splits = None
|
||||
|
||||
def get_attntion_meta(self) -> AttentionMetadata:
|
||||
"""get_attntion_meta"""
|
||||
return self.attention_metadata
|
||||
@@ -250,13 +253,13 @@ class MetaxMLAAttentionBackend(AttentionBackend):
|
||||
]
|
||||
)
|
||||
|
||||
query = query.reshape([-1, seq_len_q, num_heads_q, head_dim_qk])
|
||||
query = query.reshape_([-1, seq_len_q, num_heads_q, head_dim_qk])
|
||||
|
||||
tile_scheduler_metadata, num_splits = get_mla_metadata(
|
||||
self.seq_lens, seq_len_q * num_heads_q // num_heads_kv, num_heads_kv
|
||||
)
|
||||
|
||||
assert tile_scheduler_metadata.shape[0] != 0
|
||||
if self.tile_scheduler_metadata is None or self.num_splits is None:
|
||||
self.tile_scheduler_metadata, self.num_splits = get_mla_metadata(
|
||||
self.seq_lens, seq_len_q * num_heads_q // num_heads_kv, num_heads_kv
|
||||
)
|
||||
assert self.tile_scheduler_metadata.shape[0] != 0
|
||||
|
||||
out = flash_mla_with_kvcache(
|
||||
query,
|
||||
@@ -264,8 +267,8 @@ class MetaxMLAAttentionBackend(AttentionBackend):
|
||||
self.block_tables,
|
||||
self.seq_lens,
|
||||
head_dim_v,
|
||||
tile_scheduler_metadata,
|
||||
num_splits,
|
||||
self.tile_scheduler_metadata,
|
||||
self.num_splits,
|
||||
softmax_scale=self.attn_softmax_scale,
|
||||
causal=self.causal,
|
||||
)[0]
|
||||
@@ -273,7 +276,7 @@ class MetaxMLAAttentionBackend(AttentionBackend):
|
||||
if seq_len_q != self.seq_lens_this_time_min:
|
||||
out = paddle.concat([paddle.split(x, [n, seq_len_q - n])[0] for x, n in zip(out, self.seq_lens_this_time)])
|
||||
else:
|
||||
out = out.reshape([-1, num_heads_q, head_dim_v])
|
||||
out = out.reshape_([-1, num_heads_q, head_dim_v])
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@@ -411,14 +411,14 @@ class DeepseekV3MLAAttention(nn.Layer):
|
||||
forward_meta=forward_meta,
|
||||
)
|
||||
|
||||
fmha_out_decode = fmha_out_decode.reshape([-1, self.num_attention_heads_tp, self.kv_lora_rank]).transpose(
|
||||
fmha_out_decode = fmha_out_decode.reshape_([-1, self.num_attention_heads_tp, self.kv_lora_rank]).transpose(
|
||||
[1, 0, 2]
|
||||
)
|
||||
|
||||
fmha_out_decode = (
|
||||
self.kv_b_proj_bmm(fmha_out_decode, proj_type="v")
|
||||
.transpose([1, 0, 2])
|
||||
.reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
|
||||
.reshape_([-1, self.num_attention_heads_tp * self.v_head_dim])
|
||||
)
|
||||
|
||||
if fmha_out is None:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -120,23 +120,28 @@ class MetaxWorker(WorkerBase):
|
||||
before_run_meminfo_used = info.vramUse * 1024
|
||||
before_run_meminfo_free = before_run_meminfo_total - before_run_meminfo_used
|
||||
|
||||
logger.info("Before running the profile, the memory usage info of Metax GPU is as follows:")
|
||||
logger.info(f"Device Index: {device_id}")
|
||||
logger.info(f"Device Total memory: {before_run_meminfo_total / Gb}")
|
||||
logger.info(f"Device used memory: {before_run_meminfo_used / Gb}")
|
||||
logger.info(f"Device free memory: {before_run_meminfo_free / Gb}")
|
||||
logger.info(f"Paddle reserved memory: {paddle_reserved_mem_before_run / Gb}")
|
||||
logger.info(f"Paddle allocated memory: {paddle_allocated_mem_before_run / Gb}")
|
||||
# Report the pre-profile memory state as ONE multi-line log record.
# The pieces are joined by implicit string-literal concatenation (no commas).
# With commas the parenthesised group is a tuple, and a standard
# logging.Logger would print the tuple's repr -- quotes, commas and escaped
# "\n" -- instead of a readable report.
# NOTE(review): assumes `logger` follows the stdlib logging API -- confirm.
logger.info(
    "Before running the profile, the memory usage info is as follows:"
    f"\nDevice Index: {device_id}"
    f"\nDevice Total memory: {before_run_meminfo_total / Gb}"
    f"\nDevice used memory: {before_run_meminfo_used / Gb}"
    f"\nDevice free memory: {before_run_meminfo_free / Gb}"
    f"\nPaddle reserved memory: {paddle_reserved_mem_before_run / Gb}"
    f"\nPaddle allocated memory: {paddle_allocated_mem_before_run / Gb}"
)
|
||||
|
||||
# 2. Profile run
|
||||
self.model_runner.profile_run()
|
||||
set_random_seed(self.fd_config.model_config.seed)
|
||||
|
||||
# 3. Statistical memory information
|
||||
paddle_reserved_mem_after_run = paddle.device.max_memory_reserved(local_rank)
|
||||
paddle_allocated_mem_after_run = paddle.device.max_memory_allocated(local_rank)
|
||||
|
||||
model_block_memory_used = self.cal_theortical_kvcache()
|
||||
paddle_peak_increase = paddle_reserved_mem_after_run - paddle_allocated_mem_before_run
|
||||
paddle_peak_increase = paddle_allocated_mem_after_run - paddle_allocated_mem_before_run
|
||||
|
||||
paddle.device.empty_cache()
|
||||
|
||||
@@ -154,15 +159,19 @@ class MetaxWorker(WorkerBase):
|
||||
|
||||
end_time = time.perf_counter()
|
||||
|
||||
logger.info("After running the profile, the memory usage info of Metax GPU is as follows:")
|
||||
logger.info(f"Device Index: {device_id}")
|
||||
logger.info(f"Device Total memory: {after_run_meminfo_total / Gb}")
|
||||
logger.info(f"Device used memory: {after_run_meminfo_used / Gb}")
|
||||
logger.info(f"Device free memory: {after_run_meminfo_free / Gb}")
|
||||
logger.info(f"Paddle reserved memory: {paddle_reserved_mem_after_run / Gb}")
|
||||
logger.info(f"Paddle allocated memory: {paddle_allocated_mem_after_run / Gb}")
|
||||
logger.info(f"Paddle available_kv_cache_memory: {available_kv_cache_memory / Gb}")
|
||||
logger.info(f"Profile time: {end_time - start_time}")
|
||||
# Report the post-profile memory state and profile duration as ONE
# multi-line log record.
# The pieces are joined by implicit string-literal concatenation (no commas).
# With commas the parenthesised group is a tuple, and a standard
# logging.Logger would print the tuple's repr -- quotes, commas and escaped
# "\n" -- instead of a readable report.
# NOTE(review): assumes `logger` follows the stdlib logging API -- confirm.
logger.info(
    "After running the profile, the memory usage info is as follows:"
    f"\nDevice Index: {device_id}"
    f"\nDevice Total memory: {after_run_meminfo_total / Gb}"
    f"\nDevice used memory: {after_run_meminfo_used / Gb}"
    f"\nDevice free memory: {after_run_meminfo_free / Gb}"
    f"\nPaddle reserved memory: {paddle_reserved_mem_after_run / Gb}"
    f"\nPaddle allocated memory: {paddle_allocated_mem_after_run / Gb}"
    f"\nAvailable KV Cache meomory: {available_kv_cache_memory / Gb}"
    f"\nProfile time: {end_time - start_time}"
)
|
||||
|
||||
return available_kv_cache_memory
|
||||
|
||||
|
||||
Reference in New Issue
Block a user