Mirror of https://github.com/PaddlePaddle/FastDeploy.git, synced 2025-10-06 00:57:33 +08:00
[Bug fix] fix attention rank init (#2743)
* fix attention rank init
@@ -100,7 +100,7 @@ class FlashAttentionBackend(AttentionBackend):
         self.use_speculate = self.speculative_method is not None
         self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
         self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
-        self.rank: int = fd_config.parallel_config.tensor_parallel_rank
+        self.rank: int = fd_config.parallel_config.expert_parallel_rank * fd_config.parallel_config.tensor_parallel_degree + fd_config.parallel_config.tensor_parallel_rank
 
         # pd_disaggregation
         self.use_pd_disaggregation: int = int(
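
The first hunk changes how self.rank is initialized: instead of using the tensor-parallel rank alone, the rank is flattened across expert-parallel groups, EP-major. A minimal sketch of that computation; only the config field names come from the diff, the function name and values are illustrative:

def global_rank(expert_parallel_rank: int,
                tensor_parallel_degree: int,
                tensor_parallel_rank: int) -> int:
    # EP-major layout: all TP ranks of EP group 0 first, then EP group 1, etc.
    return expert_parallel_rank * tensor_parallel_degree + tensor_parallel_rank

# With tensor_parallel_degree=4: (ep=0, tp=0..3) -> ranks 0..3, (ep=1, tp=0..3) -> ranks 4..7.
assert global_rank(1, 4, 2) == 6
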
@@ -110,12 +110,13 @@ class FlashAttentionBackend(AttentionBackend):
 
-        if fd_config.parallel_config.expert_parallel_rank is None:
-            fd_config.parallel_config.expert_parallel_rank = 0
-        device_id = self.rank + fd_config.parallel_config.tensor_parallel_degree * \
-            fd_config.parallel_config.expert_parallel_rank
         if self.device_id is None:
-            self.device_id = device_id
+            self.device_id = self.rank
         else:
-            self.device_id = self.device_id.split(",")[device_id]
+            device_ids = self.device_id.split(",")
+            rank_index = self.rank % len(device_ids)
+            self.device_id = device_ids[rank_index]
 
     def get_attntion_meta(self):
         """get_attntion_meta"""
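
The second hunk drops the local device_id recomputation, which is now redundant because self.rank already encodes the expert-parallel offset, and instead indexes into the comma-separated device list by rank, wrapping with a modulo so the index stays in range. A minimal sketch of that selection logic, assuming device_id is either None or a comma-separated string such as a CUDA_VISIBLE_DEVICES value; the function name is hypothetical:

from typing import Optional

def resolve_device_id(device_id: Optional[str], rank: int) -> str:
    if device_id is None:
        # No explicit device list: use the global rank as the device index.
        return str(rank)
    device_ids = device_id.split(",")
    # Wrap around so every rank maps to a valid entry, even if there are
    # fewer visible devices than ranks.
    rank_index = rank % len(device_ids)
    return device_ids[rank_index]

assert resolve_device_id(None, 3) == "3"
assert resolve_device_id("4,5,6,7", 6) == "6"  # 6 % 4 == 2 -> device "6"
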