[Bug fix] fix attention rank init (#2743)

* fix attention rank init

* fix attention rank init
This commit is contained in:
RichardWooSJTU
2025-07-08 17:19:49 +08:00
committed by GitHub
parent 57b086dc6b
commit e8bbe7244b
4 changed files with 17 additions and 13 deletions

View File

@@ -91,7 +91,7 @@ class AppendAttentionBackend(AttentionBackend):
self.use_speculate: bool = self.speculative_method is not None
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
self.rank: int = fd_config.parallel_config.expert_parallel_rank * fd_config.parallel_config.tensor_parallel_degree + fd_config.parallel_config.tensor_parallel_rank
self.kv_num_heads: int = kv_num_heads
self.num_heads: int = num_heads
@@ -108,12 +108,12 @@ class AppendAttentionBackend(AttentionBackend):
if fd_config.parallel_config.expert_parallel_rank is None:
fd_config.parallel_config.expert_parallel_rank = 0
device_id = self.rank + fd_config.parallel_config.tensor_parallel_degree * \
fd_config.parallel_config.expert_parallel_rank
if self.device_id is None:
self.device_id = device_id
self.device_id = self.rank
else:
self.device_id = self.device_id.split(",")[device_id]
device_ids = self.device_id.split(",")
rank_index = self.rank % len(device_ids)
self.device_id = self.device_id[rank_index]
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""

View File

@@ -100,7 +100,7 @@ class FlashAttentionBackend(AttentionBackend):
self.use_speculate = self.speculative_method is not None
self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
self.rank: int = fd_config.parallel_config.expert_parallel_rank * fd_config.parallel_config.tensor_parallel_degree + fd_config.parallel_config.tensor_parallel_rank
# pd_disaggregation
self.use_pd_disaggregation: int = int(
@@ -110,12 +110,13 @@ class FlashAttentionBackend(AttentionBackend):
if fd_config.parallel_config.expert_parallel_rank is None:
fd_config.parallel_config.expert_parallel_rank = 0
device_id = self.rank + fd_config.parallel_config.tensor_parallel_degree * \
fd_config.parallel_config.expert_parallel_rank
if self.device_id is None:
self.device_id = device_id
self.device_id = self.rank
else:
self.device_id = self.device_id.split(",")[device_id]
device_ids = self.device_id.split(",")
rank_index = self.rank % len(device_ids)
self.device_id = self.device_id[rank_index]
def get_attntion_meta(self):
"""get_attntion_meta"""

View File

@@ -109,7 +109,7 @@ class MLAAttentionBackend(AttentionBackend):
self.use_speculate: bool = self.speculative_method is not None
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
self.rank: int = fd_config.parallel_config.expert_parallel_rank * fd_config.parallel_config.tensor_parallel_degree + fd_config.parallel_config.tensor_parallel_rank
self.kv_num_heads: int = kv_num_heads
self.num_heads: int = num_heads
@@ -135,10 +135,13 @@ class MLAAttentionBackend(AttentionBackend):
os.getenv("FLAGS_use_pd_disaggregation", 0))
self.start_layer_index: int = fd_config.model_config.start_layer_index
self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None)
if self.device_id is None:
self.device_id = self.rank
else:
self.device_id = self.device_id.split(",")[self.rank]
device_ids = self.device_id.split(",")
rank_index = self.rank % len(device_ids)
self.device_id = self.device_id[rank_index]
def init_attention_metadata(self, forward_meta: ForwardMeta):
"""Initialize attention metadata hence all layers in the forward pass can reuse it."""

View File

@@ -91,7 +91,7 @@ class XPUAttentionBackend(AttentionBackend):
# self.use_speculate = self.speculate_method is not None
# self.speculate_max_draft_token_num = fd_config.parallel_config.speculate_max_draft_tokens
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.rank: int = fd_config.parallel_config.tensor_parallel_rank
self.rank: int = fd_config.parallel_config.expert_parallel_rank * fd_config.parallel_config.tensor_parallel_degree + fd_config.parallel_config.tensor_parallel_rank
self.kv_num_heads: int = kv_num_heads
self.num_heads: int = num_heads