Revert "[Bug fix] fix attention rank init (#2743)" (#2761)

This reverts commit e8bbe7244b.
This commit is contained in:
RichardWooSJTU
2025-07-09 10:38:12 +08:00
committed by GitHub
parent f72c4de539
commit 6610aa29d0
4 changed files with 13 additions and 17 deletions

View File

@@ -109,7 +109,7 @@ class MLAAttentionBackend(AttentionBackend):
self.use_speculate: bool = self.speculative_method is not None
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.rank: int = fd_config.parallel_config.expert_parallel_rank * fd_config.parallel_config.tensor_parallel_degree + fd_config.parallel_config.tensor_parallel_rank
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
self.kv_num_heads: int = kv_num_heads
self.num_heads: int = num_heads
@@ -135,13 +135,10 @@ class MLAAttentionBackend(AttentionBackend):
os.getenv("FLAGS_use_pd_disaggregation", 0))
self.start_layer_index: int = fd_config.model_config.start_layer_index
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
if self.device_id is None:
self.device_id = self.rank
else:
device_ids = self.device_id.split(",")
rank_index = self.rank % len(device_ids)
self.device_id = self.device_id[rank_index]
self.device_id = self.device_id.split(",")[self.rank]
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attention metadata hence all layers in the forward pass can reuse it."""