diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py
index 9b4bfe840..2c7f9aef3 100644
--- a/fastdeploy/model_executor/layers/linear.py
+++ b/fastdeploy/model_executor/layers/linear.py
@@ -415,6 +415,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
     def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = None):
         model_format = getattr(param, "model_format", "")
         if model_format == "torch":
+            loaded_weight = get_tensor(loaded_weight)
             loaded_weight = loaded_weight.transpose([1, 0])
         output_dim = getattr(param, "output_dim", None)
         assert output_dim is not None
@@ -446,7 +447,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 shard_offset = self.local_rank * block_size
                 shard_size = (self.local_rank + 1) * block_size
                 loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)
-
+        loaded_weight = get_tensor(loaded_weight)
         if not param._is_initialized():
             param.initialize()
         param_shard_size = output_size // 2
@@ -548,6 +549,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         head_dim = param.shape[dim] // (self.num_heads_per_rank + 2 * self.kv_num_heads_per_rank)
         model_format = getattr(param, "model_format", "")
         if model_format == "torch":
+            loaded_weight = get_tensor(loaded_weight)
             loaded_weight = loaded_weight.transpose([1, 0])
         if loaded_shard_id is None:
             # Loaded weight is already fused on disk
@@ -568,12 +570,13 @@ class QKVParallelLinear(ColumnParallelLinear):
             # Tensor parallelism splits the weight along the output_dim
             if self.nranks != 1:
                 block_size = self._get_shard_size_mapping(loaded_shard_id)
-                dim = -1 if output_dim else 0
                 shard_id = self.local_rank if loaded_shard_id == "q" else self.local_rank // self.num_kv_head_replicas
                 shard_offset = shard_id * block_size
                 shard_size = (shard_id + 1) * block_size
                 loaded_weight = slice_fn(loaded_weight, output_dim, start=shard_offset, end=shard_size)

+            loaded_weight = get_tensor(loaded_weight)
+
         if not param._is_initialized():
             param.initialize()

diff --git a/fastdeploy/model_executor/layers/moe/moe.py b/fastdeploy/model_executor/layers/moe/moe.py
index c77379e68..bc58ef3eb 100644
--- a/fastdeploy/model_executor/layers/moe/moe.py
+++ b/fastdeploy/model_executor/layers/moe/moe.py
@@ -176,17 +176,24 @@ class FusedMoE(nn.Layer):

         if shard_id is None:
             # 1.gate up fused in disk
+            model_format = getattr(param, "model_format", "")
+            is_torch_model = model_format == "torch"
             output_size = param[expert_id - self.expert_id_offset].shape[SHARD_ID_TO_SHARDED_DIM["gate"]]
-            shard_offsets = [
-                # (shard_id, shard_offset, shard_size)
-                ("gate", 0, output_size // 2 * self.tp_size),
-                ("up", output_size // 2 * self.tp_size, output_size // 2 * self.tp_size),
-            ]
-            for shard_id, shard_offset, shard_size in shard_offsets:
-                loaded_weight_shard = slice_fn(
-                    loaded_weight, SHARD_ID_TO_SHARDED_DIM[shard_id], shard_offset, shard_offset + shard_size
-                )
-                self.weight_loader(param, loaded_weight_shard, expert_id, shard_id)
+            per_rank = output_size // 2
+            start = self.tp_rank * per_rank
+            loaded_weight_shard_gate = slice_fn(
+                loaded_weight, is_torch_model ^ SHARD_ID_TO_SHARDED_DIM["gate"], start, start + per_rank
+            )
+            self._load_gate_up_weight(
+                param, expert_id, loaded_weight_shard_gate, "gate", SHARD_ID_TO_SHARDED_DIM["gate"], is_sharded=True
+            )
+            start_up = output_size // 2 * self.tp_size + self.tp_rank * per_rank
+            loaded_weight_shard_up = slice_fn(
+                loaded_weight, is_torch_model ^ SHARD_ID_TO_SHARDED_DIM["up"], start_up, start_up + per_rank
+            )
+            self._load_gate_up_weight(
+                param, expert_id, loaded_weight_shard_up, "up", SHARD_ID_TO_SHARDED_DIM["up"], is_sharded=True
+            )
         else:
             # 2.gate up splited in disk
             assert shard_id in ["gate", "down", "up"]
@@ -198,22 +205,23 @@ class FusedMoE(nn.Layer):
                 shard_dim=SHARD_ID_TO_SHARDED_DIM[shard_id],
             )

-    def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
+    def _load_gate_up_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None, is_sharded=False):
         model_format = getattr(param, "model_format", "")
-        if model_format == "torch":
-            loaded_weight = loaded_weight.transpose([1, 0])
-        dim = -1 if shard_dim else 0
-        if self.tp_size > 1:
+        is_torch_model = model_format == "torch"
+        if self.tp_size > 1 and not is_sharded:
+            tp_shard_dim = is_torch_model ^ shard_dim
+            weight_dim = -1 if tp_shard_dim else 0
             if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
-                size = loaded_weight.shape[dim]
+                size = loaded_weight.shape[weight_dim]
             else:
-                size = loaded_weight.get_shape()[dim]
+                size = loaded_weight.get_shape()[weight_dim]
             block_size = size // self.tp_size
             shard_offset = self.tp_rank * block_size
             shard_size = (self.tp_rank + 1) * block_size
-            loaded_weight = slice_fn(loaded_weight, shard_dim, shard_offset, shard_size)
-
+            loaded_weight = slice_fn(loaded_weight, tp_shard_dim, shard_offset, shard_size)
+        loaded_weight = get_tensor(loaded_weight)
         expert_param = param[expert_id - self.expert_id_offset]
+        dim = -1 if shard_dim else 0
         param_shard_size = expert_param.shape[dim] // 2
         if shard_id == "gate":
             param_shard_offset = 0
@@ -232,9 +240,8 @@ class FusedMoE(nn.Layer):
         )

         # To ensure compatibility across backends, apply an extra transpose for GCU and XPU
-        if current_platform.is_xpu() or current_platform.is_gcu():
-            if expert_param.shape != loaded_weight.shape:
-                loaded_weight = loaded_weight.transpose([1, 0])
+        if expert_param.shape != loaded_weight.shape:
+            loaded_weight = loaded_weight.transpose([1, 0])
         assert expert_param.shape == loaded_weight.shape, (
             f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
         )
@@ -242,26 +249,26 @@ class FusedMoE(nn.Layer):

     def _load_down_weight(self, param, expert_id, loaded_weight, shard_id, shard_dim=None):
         model_format = getattr(param, "model_format", "")
-        if model_format == "torch":
-            loaded_weight = loaded_weight.transpose([1, 0])
+        is_torch_model = model_format == "torch"
         if self.tp_size > 1 and shard_dim is not None:
-            dim = -1 if shard_dim else 0
-            if isinstance(loaded_weight, (np.ndarray, paddle.Tensor)):
+            tp_shard_dim = is_torch_model ^ shard_dim
+            dim = -1 if tp_shard_dim else 0
+            if isinstance(loaded_weight, paddle.Tensor):
                 size = loaded_weight.shape[dim]
             else:
                 size = loaded_weight.get_shape()[dim]
             block_size = size // self.tp_size
             shard_offset = self.tp_rank * block_size
             shard_size = (self.tp_rank + 1) * block_size
-            loaded_weight = slice_fn(loaded_weight, shard_dim, shard_offset, shard_size)
+            loaded_weight = slice_fn(loaded_weight, tp_shard_dim, shard_offset, shard_size)
+        loaded_weight = get_tensor(loaded_weight)
         expert_param = param[expert_id - self.expert_id_offset]
         if hasattr(param, "tensor_track"):
             # for dyn quant
             param.tensor_track.mark(start=0, batch_id=expert_id - self.expert_id_offset)
-        # To ensure compatibility across backends, apply an extra transpose for GCU and XPU
-        if current_platform.is_xpu or current_platform.is_gcu():
-            if expert_param.shape != loaded_weight.shape:
-                loaded_weight = loaded_weight.transpose([1, 0])
+        # To ensure compatibility across backends, apply an extra transpose for GCU and XPU and opensource weight
+        if expert_param.shape != loaded_weight.shape:
+            loaded_weight = loaded_weight.transpose([1, 0])
         assert expert_param.shape == loaded_weight.shape, (
             f"Attempted to load weight ({loaded_weight.shape}) " f"into parameter ({expert_param.shape})"
         )
diff --git a/fastdeploy/model_executor/load_weight_utils.py b/fastdeploy/model_executor/load_weight_utils.py
index 6e1097d86..be0d76a33 100644
--- a/fastdeploy/model_executor/load_weight_utils.py
+++ b/fastdeploy/model_executor/load_weight_utils.py
@@ -29,7 +29,6 @@ from safetensors import safe_open
 from tqdm import tqdm

 from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.models.tp_utils import (
     check_tensor_parallel_prerequisites,
 )
@@ -186,8 +185,7 @@ def fast_weights_iterator(safe_tensor_list: list[str]):
         with fast_safe_open(st_file, framework="np") as f:
             for name in f.keys():
                 param_slice = f.get_slice(name)
-                paddle_tensor = get_tensor(param_slice)
-                yield name, paddle_tensor
+                yield name, param_slice


 def fastsafetensors_weights_iterator(
diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py
index c95a45438..8e90fb80f 100644
--- a/fastdeploy/model_executor/utils.py
+++ b/fastdeploy/model_executor/utils.py
@@ -160,6 +160,7 @@ def default_weight_loader(fd_config: FDConfig) -> None:
         output_dim = getattr(param, "output_dim", None)
         model_format = getattr(param, "model_format", "")
         if model_format == "torch":
+            loaded_weight = get_tensor(loaded_weight)
             loaded_weight = loaded_weight.transpose([1, 0])
         # Tensor parallelism splits the weight along the output_dim
         if output_dim is not None and fd_config.parallel_config.tensor_parallel_size > 1:
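
Illustrative sketch (not part of the patch): a minimal standalone example of the shard-dimension flip the MoE loaders above use for torch-format checkpoints, where the on-disk weight is the transpose of the Paddle layout. A NumPy array stands in for the lazy safetensors slice, and shard_for_rank is a hypothetical helper, not a FastDeploy API; the real code slices with slice_fn and materializes the tensor afterwards via get_tensor.

import numpy as np

def shard_for_rank(loaded_weight, shard_dim, tp_rank, tp_size, is_torch_model):
    # Flip the shard dim (0 <-> 1) when the stored layout is transposed.
    tp_shard_dim = is_torch_model ^ shard_dim
    size = loaded_weight.shape[-1 if tp_shard_dim else 0]
    block = size // tp_size
    start, end = tp_rank * block, (tp_rank + 1) * block
    # Slice only this rank's block; the full tensor is never materialized first.
    if tp_shard_dim:
        return loaded_weight[..., start:end]
    return loaded_weight[start:end, ...]

# A weight stored torch-style as [output, hidden]: shard_dim=1 names the output
# dimension of the Paddle [hidden, output] layout, so the flip makes rank 0 take
# the first half of the rows of the stored tensor.
w = np.arange(8 * 4).reshape(8, 4)
print(shard_for_rank(w, shard_dim=1, tp_rank=0, tp_size=2, is_torch_model=True).shape)  # (4, 4)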