[Metax] support ENABLE_V1_KVCACHE_SCHEDULER (#5163)

Author: xiaozude
Date: 2025-11-24 19:19:49 +08:00
Committed by: GitHub
parent e150a418d4
commit d5bd64336a
5 changed files with 24 additions and 14 deletions

View File

@@ -523,7 +523,7 @@ class EngineArgs:
                 f"= {expected_ports}, but got {len(self.rdma_comm_ports)}."
             )

-        if not current_platform.is_cuda() and not current_platform.is_xpu():
+        if not (current_platform.is_cuda() or current_platform.is_xpu() or current_platform.is_maca()):
             envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
         if self.guided_decoding_backend != "off":
             envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
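
Note on the gating change: the old predicate treated MACA as "not CUDA and not XPU" and therefore forced the V1 KV-cache scheduler off; the rewritten predicate adds MACA to the supported set. A minimal sketch of the behavioral difference, with a hypothetical _Platform stub standing in for current_platform:

# Hypothetical stub; only the predicate names come from the diff above.
class _Platform:
    def __init__(self, kind: str):
        self.kind = kind

    def is_cuda(self) -> bool:
        return self.kind == "cuda"

    def is_xpu(self) -> bool:
        return self.kind == "xpu"

    def is_maca(self) -> bool:
        return self.kind == "maca"

maca = _Platform("maca")
old_disable = not maca.is_cuda() and not maca.is_xpu()
new_disable = not (maca.is_cuda() or maca.is_xpu() or maca.is_maca())
assert old_disable and not new_disable  # the scheduler now stays enabled on MACA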

View File

@@ -141,8 +141,11 @@ class MetaxMLAAttentionBackend(AttentionBackend):
         self.flash_attn_func = flash_attn_unpadded_func
         self.flash_attn_kwargs = {"softmax_scale": self.attn_softmax_scale}

+    @paddle.no_grad()
     def init_attention_metadata(self, forward_meta: ForwardMeta):
         """Initialize attention metadata hence all layers in the forward pass can reuse it."""
+        paddle.device.empty_cache()
         metadata = MLAAttentionMetadata()
         metadata.max_partition_size = 32768
         metadata.encoder_max_partition_size = self.max_seq_len
@@ -203,8 +206,6 @@ class MetaxMLAAttentionBackend(AttentionBackend):
         self.seq_lens = seq_lens_decoder + seq_lens_this_time
         self.block_tables = forward_meta.block_tables[non_zero_index]
-        paddle.device.empty_cache()

     def get_attntion_meta(self) -> AttentionMetadata:
         """get_attntion_meta"""
         return self.attention_metadata
@@ -290,6 +291,8 @@ class MetaxMLAAttentionBackend(AttentionBackend):
         """
         Forward pass for the prefill stage.
         """
+        paddle.device.empty_cache()
+
         metadata = self.attention_metadata
         latent_cache = forward_meta.caches[layer.layer_id] if hasattr(forward_meta, "caches") else None
@@ -364,6 +367,7 @@ class MetaxMLAAttentionBackend(AttentionBackend):
         return fmha_out

+    @paddle.no_grad()
     def forward_mixed(
         self,
         q: paddle.Tensor,
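
The edits to this file apply one pattern to the backend's entry points: decorate them with @paddle.no_grad(), since metadata setup and the forward passes are inference-only, and flush the allocator cache at the start of init_attention_metadata rather than at the tail of its helper. A minimal sketch of the resulting shape (the class and bodies are placeholders, not the backend's real code):

import paddle

class _SketchBackend:
    @paddle.no_grad()  # no autograd state is recorded while building metadata
    def init_attention_metadata(self, forward_meta):
        paddle.device.empty_cache()  # release cached allocator blocks up front
        ...  # build and store MLAAttentionMetadata here

    @paddle.no_grad()  # the mixed forward path is likewise inference-only
    def forward_mixed(self, q):
        ...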

View File

@@ -103,12 +103,12 @@ class MetaxWorker(WorkerBase):
         Gb = 1024**3
         local_rank = self.local_rank % self.max_chips_per_node
-        paddle.device.cuda.reset_max_memory_reserved(local_rank)
-        paddle.device.cuda.reset_max_memory_allocated(local_rank)
+        paddle.device.reset_max_memory_reserved(local_rank)
+        paddle.device.reset_max_memory_allocated(local_rank)
         # max memory for Allocator
-        paddle_reserved_mem_before_run = paddle.device.cuda.max_memory_reserved(local_rank)
+        paddle_reserved_mem_before_run = paddle.device.max_memory_reserved(local_rank)
         # max memory for Tensor
-        paddle_allocated_mem_before_run = paddle.device.cuda.max_memory_allocated(local_rank)  # not reserved
+        paddle_allocated_mem_before_run = paddle.device.max_memory_allocated(local_rank)  # not reserved
         device_id = int(self.device_ids[local_rank])
         if os.getenv("MACA_VISIBLE_DEVICES") is not None:
@@ -132,13 +132,13 @@ class MetaxWorker(WorkerBase):
             self.model_runner.profile_run()

         # 3. Statistical memory information
-        paddle_reserved_mem_after_run = paddle.device.cuda.max_memory_reserved(local_rank)
-        paddle_allocated_mem_after_run = paddle.device.cuda.max_memory_allocated(local_rank)
+        paddle_reserved_mem_after_run = paddle.device.max_memory_reserved(local_rank)
+        paddle_allocated_mem_after_run = paddle.device.max_memory_allocated(local_rank)

         model_block_memory_used = self.cal_theortical_kvcache()
         paddle_peak_increase = paddle_reserved_mem_after_run - paddle_allocated_mem_before_run

-        paddle.device.cuda.empty_cache()
+        paddle.device.empty_cache()

         info = pymxsml.mxSmlGetMemoryInfo(device_id)
         after_run_meminfo_total = info.vramTotal * 1024
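
The two hunks above swap the CUDA-namespaced memory APIs for Paddle's device-generic equivalents, which dispatch to the active backend (here the MACA plugin) instead of assuming CUDA. Reduced to a sketch, and assuming a Paddle build with the target device active, the profiling sequence reads:

import paddle

dev = 0
# Reset peak counters before the profile run ...
paddle.device.reset_max_memory_reserved(dev)
paddle.device.reset_max_memory_allocated(dev)
before_allocated = paddle.device.max_memory_allocated(dev)

# ... the model would run here (profile_run in the worker) ...

# ... then read the peak and release cached blocks.
after_reserved = paddle.device.max_memory_reserved(dev)
peak_increase = after_reserved - before_allocated
paddle.device.empty_cache()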
@@ -146,8 +146,10 @@ class MetaxWorker(WorkerBase):
         after_run_meminfo_free = after_run_meminfo_total - after_run_meminfo_used
-        available_kv_cache_memory = (
-            after_run_meminfo_free - paddle_peak_increase
-        ) * self.cache_config.gpu_memory_utilization
+        available_kv_cache_memory = (
+            after_run_meminfo_total * self.cache_config.gpu_memory_utilization
+            - after_run_meminfo_used
+            - paddle_peak_increase
+        )
         available_kv_cache_memory += model_block_memory_used * self.cache_config.total_block_num

         end_time = time.perf_counter()
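
The formula change deserves a worked example. Before, the headroom fraction (1 - gpu_memory_utilization) was taken only from the memory still free after the profile run; now it is taken from the device total, and the measured usage and peak increase are subtracted from that budget. With hypothetical numbers (GiB):

total, used, peak, util = 64.0, 10.0, 2.0, 0.9
free = total - used                  # 54.0

old = (free - peak) * util           # (54 - 2) * 0.9 = 46.8
new = total * util - used - peak     # 64 * 0.9 - 10 - 2 = 45.6

The new form keeps a fixed fraction of the whole card in reserve regardless of how much memory was already held before profiling, which is the more conservative budget here.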

View File

@@ -937,7 +937,7 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     logger.info(f"- Dynamic load weight: {load_config.dynamic_load_weight}")
     logger.info(f"- Load strategy: {load_config.load_strategy}")

-    if not current_platform.is_cuda() and not current_platform.is_xpu():
+    if not (current_platform.is_cuda() or current_platform.is_xpu() or current_platform.is_maca()):
         logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported.")
         envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
     if structured_outputs_config.guided_decoding_backend != "off":

View File

@@ -10,7 +10,7 @@ tqdm
 pynvml
 uvicorn==0.29.0
 fastapi
-paddleformers>=0.2
+paddleformers==0.3.2
 redis
 etcd3
 httpx
@@ -42,3 +42,7 @@ opentelemetry-exporter-otlp
 opentelemetry-instrumentation-fastapi
 opentelemetry-instrumentation-logging
 partial_json_parser
+msgspec
+einops
+setproctitle
+aistudio_sdk