support tpep weight load (#3882)

This commit is contained in:
lizhenyun01
2025-09-05 13:56:29 +08:00
committed by GitHub
parent 55ebe855c0
commit 199f88ce1e

View File

@@ -60,7 +60,7 @@ def load_reordered_experts(model_path: str, key_name: str):
return weight return weight
def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool = False): def load_ep_checkpoint(cls: PretrainedModel, model_path: str, fd_config: FDConfig, return_numpy: bool = False):
""" """
load ep checkpoint load ep checkpoint
""" """
@@ -138,6 +138,10 @@ def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool
if k in weight_list: if k in weight_list:
filtered_map[k] = weight_list[k] filtered_map[k] = weight_list[k]
if fd_config.parallel_config.tensor_parallel_size > 1:
tp_actions = cls._get_tensor_parallel_mappings(fd_config.model_config.pretrained_config)
new_actions = {k: v for k, v in tp_actions.items() if k not in num_local_ffn_keys}
state_dict = {} state_dict = {}
# Get all safetensor file paths that need to be opened # Get all safetensor file paths that need to be opened
safetensor_paths = set(filtered_map.values()) safetensor_paths = set(filtered_map.values())
@@ -153,6 +157,9 @@ def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool
for k in filtered_map: for k in filtered_map:
if filtered_map[k] == safetensor_path and k in f.keys(): if filtered_map[k] == safetensor_path and k in f.keys():
weight = f.get_tensor(k) weight = f.get_tensor(k)
if fd_config.parallel_config.tensor_parallel_size > 1:
if k in new_actions:
weight = new_actions[k](weight)
if not return_numpy: if not return_numpy:
weight = paddle.Tensor(weight, zero_copy=True) weight = paddle.Tensor(weight, zero_copy=True)
weight = weight._copy_to(paddle.framework._current_expected_place(), False) weight = weight._copy_to(paddle.framework._current_expected_place(), False)
@@ -324,12 +331,8 @@ def load_composite_checkpoint(
# 3. Pre-sharded (pre-split) # 3. Pre-sharded (pre-split)
""" """
# (TODO: remove in the future) # (TODO: remove in the future)
if ( if fd_config.parallel_config.use_ep and fd_config.speculative_config.model_type != "mtp":
fd_config.parallel_config.use_ep state_dict = load_ep_checkpoint(cls, model_path, fd_config, return_numpy=True)
and fd_config.speculative_config.model_type != "mtp"
and fd_config.parallel_config.tensor_parallel_size == 1
):
state_dict = load_ep_checkpoint(model_path, fd_config, return_numpy=True)
else: else:
rank_dirs = [ rank_dirs = [
f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f)) f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f))