[BugFix] fix dp&ep&tp and multi-node infer (#3629)

* rm log

* fix bug

* fix bug

* fix dp&ep&tp and multi-node infer

* fix

---------

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
This commit is contained in:
gaoziyuan
2025-08-28 19:09:10 +08:00
committed by GitHub
parent 17731a8acd
commit fc635acc47
7 changed files with 48 additions and 34 deletions

View File

@@ -325,28 +325,39 @@ def load_composite_checkpoint(
# 2. Tensor Parallel (TP)
# 3. Pre-sharded (pre-split)
"""
rank_dirs = [
f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f))
]
if len(rank_dirs) > 1:
if fd_config.parallel_config.tensor_parallel_size != len(rank_dirs):
raise ValueError(f"Your model only supports loading with tp{len(rank_dirs)}")
state_dict = load_pre_sharded_checkpoint(
model_path,
fd_config.parallel_config.tensor_parallel_rank,
use_fastsafetensor=False,
)
# (TODO: remove in the future)
if (
fd_config.parallel_config.use_ep
and fd_config.speculative_config.model_type != "mtp"
and fd_config.parallel_config.tensor_parallel_size == 1
):
state_dict = load_ep_checkpoint(model_path, fd_config, return_numpy=True)
else:
if fd_config.load_config.use_fastsafetensor and (current_platform.available() and current_platform.is_cuda()):
state_dict = load_tp_checkpoint_v1(model_path, cls, fd_config, use_fastsafetensor=True)
deal_state_dict(state_dict)
else:
state_dict = load_tp_checkpoint(
rank_dirs = [
f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f))
]
if len(rank_dirs) > 1:
if fd_config.parallel_config.tensor_parallel_size != len(rank_dirs):
raise ValueError(f"Your model only supports loading with tp{len(rank_dirs)}")
state_dict = load_pre_sharded_checkpoint(
model_path,
cls,
fd_config.model_config.pretrained_config,
return_numpy=return_numpy,
fd_config.parallel_config.tensor_parallel_rank,
use_fastsafetensor=False,
)
else:
if fd_config.load_config.use_fastsafetensor and (
current_platform.available() and current_platform.is_cuda()
):
state_dict = load_tp_checkpoint_v1(model_path, cls, fd_config, use_fastsafetensor=True)
deal_state_dict(state_dict)
else:
# NOTE: for very big model, cpu will be out of memory
state_dict = load_tp_checkpoint(
model_path,
cls,
fd_config.model_config.pretrained_config,
return_numpy=return_numpy,
)
if not state_dict:
raise ValueError("weight not found in state_dict !")
return state_dict