Mirror of https://github.com/PaddlePaddle/FastDeploy.git
[Metax] optimize mla attention (#5258)
build.sh
@@ -162,12 +162,11 @@ function copy_ops(){
     is_maca=`$python -c "import paddle; print(paddle.device.is_compiled_with_custom_device('metax_gpu'))"`
     if [ "$is_maca" = "True" ]; then
         DEVICE_TYPE="metax_gpu"
-        mkdir -p ../fastdeploy/model_executor/ops/base
-        cp -r ${OPS_TMP_DIR_BASE}/${WHEEL_BASE_NAME}/* ../fastdeploy/model_executor/ops/base
         cp -r ${TMP_PACKAGE_DIR}/* ../fastdeploy/model_executor/ops/gpu
         echo -e "MACA ops have been copy to fastdeploy"
         return
     fi

     is_intel_hpu=`$python -c "import paddle; print(paddle.is_compiled_with_custom_device('intel_hpu'))"`
     if [ "$is_intel_hpu" = "True" ]; then
         DEVICE_TYPE="intel-hpu"
@@ -43,6 +43,7 @@ from fastdeploy.model_executor.layers.attention.base_attention_backend import (
     AttentionMetadata,
 )
 from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
+from fastdeploy.platforms import current_platform


 @dataclass
@@ -87,7 +88,10 @@ def allocate_launch_related_buffer(
     res = {}
     res["decoder_batch_ids"] = paddle.full([decode_max_tile_size], 0, dtype="int32")
     res["decoder_tile_ids_per_batch"] = paddle.full([decode_max_tile_size], 0, dtype="int32")
-    res["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
+    if current_platform.is_maca():
+        res["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").cpu()
+    else:
+        res["decoder_num_blocks_cpu"] = paddle.full([1], 0, dtype="int32").pin_memory()
     # NOTE: (changwenbin) MLA kernel only needs decoder_num_blocks_device in place of GPU tensor,
     # adapted to cudagraph.
     res["decoder_num_blocks_device"] = paddle.full([1], 0, dtype="int32")
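Pinned (page-locked) host memory is a CUDA-side optimization for asynchronous host/device copies; on the Metax (MACA) custom device the hunk above keeps the launch counter as a plain CPU tensor instead. A minimal sketch of that platform-gated allocation, using only the paddle.full and current_platform.is_maca() calls shown in the hunk; the helper name itself is illustrative:

import paddle

from fastdeploy.platforms import current_platform


def alloc_decoder_num_blocks_cpu():
    """Host-side counter read by the MLA launch path.

    On Metax a plain CPU tensor is used; elsewhere the tensor is pinned so
    async device-to-host copies stay cheap.
    """
    blocks = paddle.full([1], 0, dtype="int32")
    return blocks.cpu() if current_platform.is_maca() else blocks.pin_memory()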
@@ -206,6 +206,9 @@ class MetaxMLAAttentionBackend(AttentionBackend):
         self.seq_lens = seq_lens_decoder + seq_lens_this_time
         self.block_tables = forward_meta.block_tables[non_zero_index]

+        self.tile_scheduler_metadata = None
+        self.num_splits = None
+
     def get_attntion_meta(self) -> AttentionMetadata:
         """get_attntion_meta"""
         return self.attention_metadata
@@ -250,13 +253,13 @@ class MetaxMLAAttentionBackend(AttentionBackend):
             ]
         )

-        query = query.reshape([-1, seq_len_q, num_heads_q, head_dim_qk])
+        query = query.reshape_([-1, seq_len_q, num_heads_q, head_dim_qk])

-        tile_scheduler_metadata, num_splits = get_mla_metadata(
-            self.seq_lens, seq_len_q * num_heads_q // num_heads_kv, num_heads_kv
-        )
-        assert tile_scheduler_metadata.shape[0] != 0
+        if self.tile_scheduler_metadata is None or self.num_splits is None:
+            self.tile_scheduler_metadata, self.num_splits = get_mla_metadata(
+                self.seq_lens, seq_len_q * num_heads_q // num_heads_kv, num_heads_kv
+            )
+        assert self.tile_scheduler_metadata.shape[0] != 0

         out = flash_mla_with_kvcache(
             query,
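With the two fields reset to None when the per-step metadata is built (previous hunk) and filled on first use here, get_mla_metadata runs once per step and every later layer reuses the cached tile_scheduler_metadata / num_splits. A minimal sketch of that compute-once pattern with the expensive call stubbed out; the class and function names are illustrative, only the caching logic mirrors the hunk:

from typing import Any, Optional, Tuple


def expensive_metadata(seq_lens: Any) -> Tuple[Any, Any]:
    # Stand-in for get_mla_metadata(...); only the call pattern matters here.
    return ("tile_scheduler_metadata", "num_splits")


class MlaMetadataCache:
    """Compute FlashMLA scheduler metadata once per step, then reuse it."""

    def __init__(self) -> None:
        self.tile_scheduler_metadata: Optional[Any] = None
        self.num_splits: Optional[Any] = None

    def reset(self) -> None:
        # Invalidate when a new batch is prepared and seq_lens may change.
        self.tile_scheduler_metadata = None
        self.num_splits = None

    def get(self, seq_lens: Any) -> Tuple[Any, Any]:
        if self.tile_scheduler_metadata is None or self.num_splits is None:
            self.tile_scheduler_metadata, self.num_splits = expensive_metadata(seq_lens)
        return self.tile_scheduler_metadata, self.num_splits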
@@ -264,8 +267,8 @@ class MetaxMLAAttentionBackend(AttentionBackend):
             self.block_tables,
             self.seq_lens,
             head_dim_v,
-            tile_scheduler_metadata,
-            num_splits,
+            self.tile_scheduler_metadata,
+            self.num_splits,
             softmax_scale=self.attn_softmax_scale,
             causal=self.causal,
         )[0]
@@ -273,7 +276,7 @@ class MetaxMLAAttentionBackend(AttentionBackend):
         if seq_len_q != self.seq_lens_this_time_min:
             out = paddle.concat([paddle.split(x, [n, seq_len_q - n])[0] for x, n in zip(out, self.seq_lens_this_time)])
         else:
-            out = out.reshape([-1, num_heads_q, head_dim_v])
+            out = out.reshape_([-1, num_heads_q, head_dim_v])

         return out

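Several reshape calls on the decode hot path are switched from reshape to the in-place reshape_ variant, which changes the tensor's view without materializing a separate output tensor. A minimal sketch of the difference, assuming a recent Paddle build where Tensor.reshape_ is available:

import paddle

x = paddle.arange(24, dtype="float32")

# Out-of-place: returns a new tensor object; x keeps its original shape.
y = x.reshape([2, 3, 4])
print(x.shape, y.shape)  # [24] [2, 3, 4]

# In-place: modifies x itself and skips the extra allocation.
x.reshape_([2, 3, 4])
print(x.shape)  # [2, 3, 4]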
@@ -411,14 +411,14 @@ class DeepseekV3MLAAttention(nn.Layer):
                 forward_meta=forward_meta,
             )

-            fmha_out_decode = fmha_out_decode.reshape([-1, self.num_attention_heads_tp, self.kv_lora_rank]).transpose(
+            fmha_out_decode = fmha_out_decode.reshape_([-1, self.num_attention_heads_tp, self.kv_lora_rank]).transpose(
                 [1, 0, 2]
             )

             fmha_out_decode = (
                 self.kv_b_proj_bmm(fmha_out_decode, proj_type="v")
                 .transpose([1, 0, 2])
-                .reshape([-1, self.num_attention_heads_tp * self.v_head_dim])
+                .reshape_([-1, self.num_attention_heads_tp * self.v_head_dim])
             )

             if fmha_out is None:
File diff suppressed because it is too large
@@ -120,23 +120,28 @@ class MetaxWorker(WorkerBase):
         before_run_meminfo_used = info.vramUse * 1024
         before_run_meminfo_free = before_run_meminfo_total - before_run_meminfo_used

-        logger.info("Before running the profile, the memory usage info of Metax GPU is as follows:")
-        logger.info(f"Device Index: {device_id}")
-        logger.info(f"Device Total memory: {before_run_meminfo_total / Gb}")
-        logger.info(f"Device used memory: {before_run_meminfo_used / Gb}")
-        logger.info(f"Device free memory: {before_run_meminfo_free / Gb}")
-        logger.info(f"Paddle reserved memory: {paddle_reserved_mem_before_run / Gb}")
-        logger.info(f"Paddle allocated memory: {paddle_allocated_mem_before_run / Gb}")
+        logger.info(
+            (
+                "Before running the profile, the memory usage info is as follows:",
+                f"\nDevice Index: {device_id}",
+                f"\nDevice Total memory: {before_run_meminfo_total / Gb}",
+                f"\nDevice used memory: {before_run_meminfo_used / Gb}",
+                f"\nDevice free memory: {before_run_meminfo_free / Gb}",
+                f"\nPaddle reserved memory: {paddle_reserved_mem_before_run / Gb}",
+                f"\nPaddle allocated memory: {paddle_allocated_mem_before_run / Gb}",
+            )
+        )

         # 2. Profile run
         self.model_runner.profile_run()
+        set_random_seed(self.fd_config.model_config.seed)

         # 3. Statistical memory information
         paddle_reserved_mem_after_run = paddle.device.max_memory_reserved(local_rank)
         paddle_allocated_mem_after_run = paddle.device.max_memory_allocated(local_rank)

         model_block_memory_used = self.cal_theortical_kvcache()
-        paddle_peak_increase = paddle_reserved_mem_after_run - paddle_allocated_mem_before_run
+        paddle_peak_increase = paddle_allocated_mem_after_run - paddle_allocated_mem_before_run

         paddle.device.empty_cache()

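The peak-memory delta used for KV-cache sizing is now taken purely from the allocated-memory counters (max_memory_allocated before vs. after the profile run) instead of subtracting an allocated value from a reserved one, so allocator cache growth is not counted as model usage. A minimal sketch of that measurement, relying only on the paddle.device.max_memory_allocated call the hunk itself uses; the helper and its arguments are illustrative:

import paddle


def measure_profile_peak_increase(run_profile, device_id: int) -> int:
    """Return how much the allocation peak grew during a profile run (bytes)."""
    allocated_before = paddle.device.max_memory_allocated(device_id)
    run_profile()  # e.g. the worker's model_runner.profile_run()
    allocated_after = paddle.device.max_memory_allocated(device_id)
    # Comparing allocated-to-allocated keeps allocator-reserved (cached but
    # unused) memory out of the estimate.
    return allocated_after - allocated_before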
@@ -154,15 +159,19 @@ class MetaxWorker(WorkerBase):

         end_time = time.perf_counter()

-        logger.info("After running the profile, the memory usage info of Metax GPU is as follows:")
-        logger.info(f"Device Index: {device_id}")
-        logger.info(f"Device Total memory: {after_run_meminfo_total / Gb}")
-        logger.info(f"Device used memory: {after_run_meminfo_used / Gb}")
-        logger.info(f"Device free memory: {after_run_meminfo_free / Gb}")
-        logger.info(f"Paddle reserved memory: {paddle_reserved_mem_after_run / Gb}")
-        logger.info(f"Paddle allocated memory: {paddle_allocated_mem_after_run / Gb}")
-        logger.info(f"Paddle available_kv_cache_memory: {available_kv_cache_memory / Gb}")
-        logger.info(f"Profile time: {end_time - start_time}")
+        logger.info(
+            (
+                "After running the profile, the memory usage info is as follows:",
+                f"\nDevice Index: {device_id}",
+                f"\nDevice Total memory: {after_run_meminfo_total / Gb}",
+                f"\nDevice used memory: {after_run_meminfo_used / Gb}",
+                f"\nDevice free memory: {after_run_meminfo_free / Gb}",
+                f"\nPaddle reserved memory: {paddle_reserved_mem_after_run / Gb}",
+                f"\nPaddle allocated memory: {paddle_allocated_mem_after_run / Gb}",
+                f"\nAvailable KV Cache meomory: {available_kv_cache_memory / Gb}",
+                f"\nProfile time: {end_time - start_time}",
+            )
+        )

         return available_kv_cache_memory

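Both memory reports are collapsed from many single-line logger.info calls into one call per report. Note that passing a tuple to logger.info emits the tuple's repr; a minimal sketch of producing one genuinely multi-line record by joining the parts first, shown with illustrative values and a stock logging.Logger rather than the worker's own logger:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("metax_worker_example")

device_id = 0
total_gb, used_gb, free_gb = 64.0, 12.5, 51.5  # illustrative numbers

# Join the parts into a single string so one log record carries the whole report.
report = "\n".join(
    [
        "After running the profile, the memory usage info is as follows:",
        f"Device Index: {device_id}",
        f"Device Total memory: {total_gb}",
        f"Device used memory: {used_gb}",
        f"Device free memory: {free_gb}",
    ]
)
logger.info(report)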