mirror of https://github.com/PaddlePaddle/FastDeploy.git
[fix] fix ep group all-reduce
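
Gate expert-parallel group creation behind enable_expert_parallel, reset the custom group id after creating the EP group, and pass the tensor-parallel group explicitly to the MoE all-reduce.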
@@ -350,9 +350,12 @@ class ParallelConfig:
                     (self.data_parallel_rank + 1) * self.tensor_parallel_size,
                 )
             )
             dist.collective._set_custom_gid(None)
             # same ep group id
-            dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
-            self.ep_group = dist.new_group(range(self.expert_parallel_size))
+            if self.enable_expert_parallel:
+                dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
+                self.ep_group = dist.new_group(range(self.expert_parallel_size))
+                dist.collective._set_custom_gid(None)
             logger.info(
                 f"data_parallel_size: {self.data_parallel_size}, tensor_parallel_size: {self.tensor_parallel_size}, expert_parallel_size: {self.expert_parallel_size}, data_parallel_rank: {self.data_parallel_rank}, tensor_parallel_rank: {self.tensor_parallel_rank}, expert_parallel_rank: {self.expert_parallel_rank}, tp_group: {self.tp_group}."
             )
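
This hunk creates the EP communicator only when expert parallelism is enabled, and resets the custom group id afterwards; the old code set the gid unconditionally and never cleared it, leaving a stale custom gid for any later new_group() call. A minimal sketch of the pattern, assuming a config object cfg with the fields used in the diff (build_ep_group and cfg are illustrative names, not FastDeploy API, and paddle.distributed must already be initialized):

import paddle.distributed as dist

def build_ep_group(cfg):
    # Every rank must set the same custom gid before new_group() so all
    # ranks construct the same communicator; resetting it to None
    # afterwards lets later new_group() calls use default ids again.
    ep_group = None
    if cfg.enable_expert_parallel:
        dist.collective._set_custom_gid(cfg.data_parallel_size + cfg.tp_gid_offset)
        ep_group = dist.new_group(range(cfg.expert_parallel_size))
        dist.collective._set_custom_gid(None)
    return ep_group
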
@@ -298,7 +298,7 @@ class CutlassMoEMethod(UnquantizedFusedMoEMethod):
         )

         if layer.reduce_results and layer.tp_size > 1:
-            tensor_model_parallel_all_reduce(fused_moe_out)
+            tensor_model_parallel_all_reduce(fused_moe_out, layer.fd_config.parallel_config.tp_group)

         return fused_moe_out
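
This hunk passes the tensor-parallel group explicitly instead of letting the all-reduce fall back to a default group, which with expert parallelism enabled may span the wrong set of ranks. A minimal sketch of the underlying call, assuming an initialized paddle.distributed environment (tp_all_reduce is an illustrative name; FastDeploy's actual helper is tensor_model_parallel_all_reduce):

import paddle.distributed as dist

def tp_all_reduce(tensor, tp_group):
    # Sum over tp_group only: each tensor-parallel rank holds a partial
    # MoE output, and the reduction must stay inside the tensor-parallel
    # ranks of the current data-parallel replica.
    dist.all_reduce(tensor, group=tp_group)
    return tensor
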