Supports DP+TP+EP hybrid parallel deployment strategy (#3489)

* Support DP+TP+EP hybrid parallel deployment strategy

* Support DP+TP+EP hybrid parallel deployment strategy

* fix conflict

* add moe_tp_ep function split_allgather_out

* del tp_group in moe_cutlass_backend

* for ci

* fix parallel_config for ci

* del log
This commit is contained in:
lzy
2025-08-26 15:04:01 +08:00
committed by GitHub
parent 52eda7fdb3
commit d339df2e90
15 changed files with 304 additions and 224 deletions

View File

@@ -321,33 +321,28 @@ def load_composite_checkpoint(
# 2. Tensor Parallel (TP)
# 3. Pre-sharded (pre-split)
"""
if fd_config.parallel_config.use_ep and fd_config.speculative_config.model_type != "mtp":
state_dict = load_ep_checkpoint(model_path, fd_config, return_numpy=True)
rank_dirs = [
f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f))
]
if len(rank_dirs) > 1:
if fd_config.parallel_config.tensor_parallel_size != len(rank_dirs):
raise ValueError(f"Your model only supports loading with tp{len(rank_dirs)}")
state_dict = load_pre_sharded_checkpoint(
model_path,
fd_config.parallel_config.tensor_parallel_rank,
use_fastsafetensor=False,
)
else:
rank_dirs = [
f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f))
]
if len(rank_dirs) > 1:
if fd_config.parallel_config.tensor_parallel_size != len(rank_dirs):
raise ValueError(f"Your model only supports loading with tp{len(rank_dirs)}")
state_dict = load_pre_sharded_checkpoint(
model_path,
fd_config.parallel_config.tensor_parallel_rank,
use_fastsafetensor=False,
)
if fd_config.load_config.use_fastsafetensor and (current_platform.available() and current_platform.is_cuda()):
state_dict = load_tp_checkpoint_v1(model_path, cls, fd_config, use_fastsafetensor=True)
deal_state_dict(state_dict)
else:
if fd_config.load_config.use_fastsafetensor and (
current_platform.available() and current_platform.is_cuda()
):
state_dict = load_tp_checkpoint_v1(model_path, cls, fd_config, use_fastsafetensor=True)
deal_state_dict(state_dict)
else:
state_dict = load_tp_checkpoint(
model_path,
cls,
fd_config.model_config.pretrained_config,
return_numpy=return_numpy,
)
state_dict = load_tp_checkpoint(
model_path,
cls,
fd_config.model_config.pretrained_config,
return_numpy=return_numpy,
)
if not state_dict:
raise ValueError("weight not found in state_dict !")
return state_dict