diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py
index be0d76a33..6156fc479 100644
--- a/fastdeploy/model_executor/load_weight_utils.py
+++ b/fastdeploy/model_executor/load_weight_utils.py
@@ -60,7 +60,7 @@ def load_reordered_experts(model_path: str, key_name: str):
     return weight
 
 
-def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool = False):
+def load_ep_checkpoint(cls: PretrainedModel, model_path: str, fd_config: FDConfig, return_numpy: bool = False):
     """
     load ep checkpoint
     """
@@ -138,6 +138,10 @@ def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool
         if k in weight_list:
             filtered_map[k] = weight_list[k]
 
+    if fd_config.parallel_config.tensor_parallel_size > 1:
+        tp_actions = cls._get_tensor_parallel_mappings(fd_config.model_config.pretrained_config)
+        new_actions = {k: v for k, v in tp_actions.items() if k not in num_local_ffn_keys}
+
     state_dict = {}
     # Get all safetensor file paths that need to be opened
     safetensor_paths = set(filtered_map.values())
@@ -153,6 +157,9 @@ def load_ep_checkpoint(model_path: str, fd_config: FDConfig, return_numpy: bool
             for k in filtered_map:
                 if filtered_map[k] == safetensor_path and k in f.keys():
                     weight = f.get_tensor(k)
+                    if fd_config.parallel_config.tensor_parallel_size > 1:
+                        if k in new_actions:
+                            weight = new_actions[k](weight)
                     if not return_numpy:
                         weight = paddle.Tensor(weight, zero_copy=True)
                         weight = weight._copy_to(paddle.framework._current_expected_place(), False)
@@ -324,12 +331,8 @@ def load_composite_checkpoint(
     # 3. Pre-sharded (pre-split)
     """
     # (TODO: remove in the future)
-    if (
-        fd_config.parallel_config.use_ep
-        and fd_config.speculative_config.model_type != "mtp"
-        and fd_config.parallel_config.tensor_parallel_size == 1
-    ):
-        state_dict = load_ep_checkpoint(model_path, fd_config, return_numpy=True)
+    if fd_config.parallel_config.use_ep and fd_config.speculative_config.model_type != "mtp":
+        state_dict = load_ep_checkpoint(cls, model_path, fd_config, return_numpy=True)
     else:
         rank_dirs = [
             f for f in os.listdir(model_path) if f.startswith("rank") and os.path.isdir(os.path.join(model_path, f))
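
For context on the new tensor-parallel branch: below is a minimal, self-contained sketch of the pattern the diff relies on, where a mapping from parameter names to shard-slicing callables (here built by a hypothetical `column_split` helper with illustrative parameter names and shapes, not FastDeploy's actual mappings) is filtered against already-expert-partitioned keys and then applied to each weight as it is read. The real actions come from `cls._get_tensor_parallel_mappings`.

```python
# Sketch only: column_split, the parameter names, and the shapes are
# illustrative assumptions standing in for _get_tensor_parallel_mappings.
import numpy as np


def column_split(tp_rank: int, tp_size: int):
    """Return an action that keeps this rank's column shard of a 2-D weight."""

    def action(weight: np.ndarray) -> np.ndarray:
        cols = weight.shape[-1] // tp_size
        return weight[..., tp_rank * cols : (tp_rank + 1) * cols]

    return action


# tp_actions maps parameter names to shard-slicing callables; expert (FFN)
# weights already partitioned by expert parallelism are excluded, so only
# the remaining weights get re-split for tensor parallelism.
tp_actions = {"model.layers.0.qkv_proj.weight": column_split(tp_rank=0, tp_size=2)}
num_local_ffn_keys = {"model.layers.0.mlp.experts.3.up_proj.weight"}
new_actions = {k: v for k, v in tp_actions.items() if k not in num_local_ffn_keys}

# Applying an action while loading, mirroring the per-key lookup in the diff.
full = np.ones((8, 16), dtype=np.float32)
key = "model.layers.0.qkv_proj.weight"
weight = new_actions[key](full) if key in new_actions else full
assert weight.shape == (8, 8)  # rank 0 keeps the first half of the columns
```

This is why the `tensor_parallel_size == 1` guard can be dropped from `load_composite_checkpoint`: with the actions applied inside `load_ep_checkpoint`, the EP path now handles TP > 1 as well.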